diff --git a/.gitignore b/.gitignore index a52d3738..b3458152 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ /.classpath /.project /.settings/ +/engines/ +/models/ # Log file *.log @@ -26,3 +28,7 @@ # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml hs_err_pid* + +# Idea files +.idea/ +lib/ diff --git a/pom.xml b/pom.xml index 56d1cb69..45168b3b 100644 --- a/pom.xml +++ b/pom.xml @@ -5,13 +5,17 @@ org.scijava pom-scijava - 31.1.0 + 35.1.1 io.github.deepimagej DeepImageJ_ +<<<<<<< HEAD 2.1.16 +======= + 3.0.1 +>>>>>>> development DeepImageJ A user-friendly plugin to run deep learning models in ImageJ. @@ -127,10 +131,7 @@ Universidad Carlos III de Madrid. deploy-to-scijava - 3.2.0 - 1.15.0 5.11.0 - 1.1.6 @@ -140,35 +141,23 @@ Universidad Carlos III de Madrid. + - com.google.protobuf - protobuf-java - ${protobuf.version} - - - org.tensorflow - proto - - - org.tensorflow - libtensorflow - - - org.tensorflow - libtensorflow_jni + io.bioimage + dl-modelrunner + 0.3.10 + org.yaml snakeyaml - - net.imagej - imagej-tensorflow - + net.imagej ij +<<<<<<< HEAD ai.djl.pytorch @@ -208,12 +197,13 @@ Universidad Carlos III de Madrid. org.jetbrains.kotlin kotlin-stdlib +======= +>>>>>>> development junit junit test - diff --git a/src/main/java/DeepImageJ_Build_BundledModel.java b/src/main/java/DeepImageJ_Build_BundledModel.java deleted file mode 100755 index af19939a..00000000 --- a/src/main/java/DeepImageJ_Build_BundledModel.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -import java.io.File; -import java.io.IOException; -import java.lang.reflect.Method; -import java.net.URL; -import java.net.URLClassLoader; -import java.util.Arrays; - -import deepimagej.BuildDialog; -import ij.IJ; -import ij.ImageJ; -import ij.ImagePlus; -import ij.plugin.PlugIn; - -public class DeepImageJ_Build_BundledModel implements PlugIn { - - public void run(String arg) { - - // If there is no models directory inside Fiji folder, create it - String path = IJ.getDirectory("imagej") + File.separator + "models" + File.separator; - - if (!(new File(path).isDirectory())) - new File(path).mkdirs(); - - BuildDialog bd = new BuildDialog(); - bd.showDialog(); - } - - public static void main(String args[]) throws IOException { - ImagePlus imp = IJ.openImage("C:\\Users\\Carlos(tfg)\\Videos\\Fiji.app\\models\\exemplary-image-data\\tribolium.tif"); - if (imp != null) - imp.show(); - new ImageJ(); - BuildDialog bd = new BuildDialog(); - bd.showDialog();/* - if (WindowManager.getCurrentImage() == null) { - IJ.error("There should be an image open."); - } else { - BuildDialog bd = new BuildDialog(); - bd.showDialog(); - }*/ - } -} \ No newline at end of file diff --git a/src/main/java/DeepImageJ_Run.java b/src/main/java/DeepImageJ_Run.java index ca717f24..9d40d048 100755 --- a/src/main/java/DeepImageJ_Run.java +++ b/src/main/java/DeepImageJ_Run.java @@ -58,22 +58,26 @@ import java.awt.event.ItemEvent; import java.awt.event.ItemListener; import java.io.File; +import java.io.IOException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import deepimagej.DeepImageJ; import deepimagej.Constants; -import deepimagej.RunnerTf; import deepimagej.RunnerProgress; -import deepimagej.RunnerPt; +import deepimagej.RunnerDL; import deepimagej.DeepLearningModel; import deepimagej.components.BorderPanel; import deepimagej.exceptions.MacrosError; +import deepimagej.modelrunner.EngineInstaller; import deepimagej.processing.HeadlessProcessing; import deepimagej.tools.ArrayOperations; import deepimagej.tools.DijRunnerPostprocessing; @@ -82,7 +86,6 @@ import deepimagej.tools.Index; import deepimagej.tools.Log; import deepimagej.tools.ModelLoader; -import deepimagej.tools.StartTensorflowService; import deepimagej.tools.SystemUsage; import ij.IJ; import ij.ImagePlus; @@ -90,11 +93,19 @@ import ij.WindowManager; import ij.gui.GenericDialog; import ij.plugin.PlugIn; +import io.bioimage.modelrunner.bioimageio.download.DownloadTracker.TwoParameterConsumer; +import io.bioimage.modelrunner.engine.EngineInfo; +import io.bioimage.modelrunner.engine.installation.EngineManagement; +import io.bioimage.modelrunner.exceptions.LoadEngineException; +import io.bioimage.modelrunner.model.Model; +import io.bioimage.modelrunner.versionmanagement.DeepLearningVersion; +import io.bioimage.modelrunner.versionmanagement.InstalledEngines; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.stream.Collectors; import javax.swing.JButton; @@ -141,15 +152,34 @@ public class DeepImageJ_Run implements PlugIn, ItemListener, Runnable, ActionLis // has been open or not for the testing private int nOpenImages = 0; + /** + * List of the installed DL frameworks compatible with this OS + */ + private List installedEngines; + + /** + * Create the String to engines directory + */ + private static final String JARS_DIRECTORY = new File("engines").getAbsolutePath(); + /** + * Track of threads that have been opened during execution and have to be closed + */ + private ArrayList extraThreads = new ArrayList(); static public void main(String args[]) { +<<<<<<< HEAD path = System.getProperty("user.home") + File.separator + "Google Drive" + File.separator + "ImageJ" + File.separator + "models" + File.separator; path = "C:\\Users\\Carlos(tfg)\\Pictures\\Fiji.app\\models" + File.separator; path = "C:\\Users\\angel\\OneDrive\\Documentos\\deepimagej\\fiji-win64\\Fiji.app\\models" + File.separator; //ImagePlus imp = IJ.openImage("C:\\Users\\Carlos(tfg)\\Desktop\\Fiji.app\\models\\Usiigaci_2.1.4\\usiigaci.tif"); ImagePlus imp = IJ.openImage("C:\\Users\\angel\\OneDrive\\Documentos\\deepimagej\\fiji-win64\\Fiji.app\\models\\pt\\sample_input_0.tif"); //ImagePlus imp = IJ.createImage("aux", 64, 64, 1, 24); +======= + path = new File("models").getAbsolutePath(); + ImagePlus imp = null; + //imp = IJ.openImage(path + "\\pr\\sample_input_0.tif"); +>>>>>>> development if (imp != null) imp.show(); WindowManager.setTempCurrentImage(imp); new DeepImageJ_Run().run(""); @@ -157,10 +187,13 @@ static public void main(String args[]) { @Override public void run(String arg) { + System.out.println("engines jars directory is "+JARS_DIRECTORY); + testMode = false; headless = GraphicsEnvironment.isHeadless(); +// headless = true; // true only for debug headless testing isMacro = IJ.isMacro(); @@ -228,9 +261,13 @@ public void run(String arg) { if (isMacro || headless) { // Macro argument String macroArg = Macro.getOptions(); - // Names of the variables needed to run DIJ - // Especially Pytorch, add the possibility of including - // the path to the model directory. See DeepImageJ wiki for more + +// macroArg = "model=NucleiSegmentationBoundaryModel format=Onnx preprocessing=[zero_mean_unit_variance.ijm] postprocessing=[no postprocessing] axes=C,Y,X tile=1,288,288 logging=Normal"; +// macroArg = "model=[StarDist H&E Nuclei Segmentation] format=Tensorflow preprocessing=[per_sample_scale_range.ijm] postprocessing=[no postprocessing] axes=Y,X,C tile=496,704,3 logging=Normal"; + + /* Names of the variables needed to run DIJ + Especially Pytorch, add the possibility of including + the path to the model directory. See DeepImageJ wiki for more */ String[] varNames = new String[] {"model", "format", "preprocessing", "postprocessing", "axes", "tile", "logging", "model_dir"}; try { @@ -247,16 +284,23 @@ public void run(String arg) { // If it is a macro, load models, tf and pt directly in the same thread. // If this was done in another thread, the plugin would try to execute the // models before everything was ready + System.out.println("[DEBUG] Start loading models"); loadModels(); - String engine = args[1]; - // In macro or headless mode, only the needed engines are loaded - boolean loadTf = engine.toLowerCase().contentEquals("tensorflow"); - boolean loadPt = engine.toLowerCase().contentEquals("pytorch"); - loadTfAndPytorch(loadTf, loadPt); + System.out.println("[DEBUG] Finished loading models"); + try { + findAvailableEngines(); + System.out.println("[DEBUG] Engines found"); + } catch (IOException e) { + IJ.error("Unable to find an engines directory. Please create" + + System.lineSeparator() + "a folder called" + + " engines inside the ImageJ/Fiji folder."); + } // Get the index of the model selected in the list of models String index = Integer.toString(Index.indexOf(items, args[0])); // Select the model name using its index in the list args[0] = fullnames.get(index); + // Put the framework name in lower case + args[1] = args[1].toLowerCase(); } else { args = createAndShowDialog(); } @@ -331,6 +375,7 @@ public String[] createAndShowDialog() { // th plugin is not run from a macro Thread thread = new Thread(this); thread.start(); + extraThreads.add(thread); // Set the 'ok' button and the model choice // combo box disabled until Tf and Pt are loaded @@ -347,11 +392,10 @@ public String[] createAndShowDialog() { return null; } for (String kk : dps.keySet()) { - if (dps.get(kk).getTfModel() != null) - dps.get(kk).getTfModel().close(); - else if (dps.get(kk).getTorchModel() != null) - dps.get(kk).getTorchModel().close(); + if (dps.get(kk).getModel() != null) + dps.get(kk).getModel().closeModel(); } + this.closeAllThreads(); return null; } String[] args = retrieveDialogParamas(); @@ -363,11 +407,13 @@ public String[] retrieveDialogParamas() { String index = Integer.toString(choices[0].getSelectedIndex()); if ((index.equals("-1") || index.equals("0")) && loadedEngine) { IJ.error("Select a valid model."); - run(""); + if (!this.isMacro && !this.headless) + run(""); return null; } else if ((choices[0].getSelectedIndex() == -1 ||choices[0].getSelectedIndex() == 0) && !loadedEngine) { IJ.error("Please wait until the Deep Learning engines are loaded."); - run(""); + if (!this.isMacro && !this.headless) + run(""); return null; } @@ -416,245 +462,288 @@ public String[] retrieveDialogParamas() { public void arrangeParametersAndRunModel(ImagePlus imp, String[] args) { // If the args are null, something wrong happened - if (args == null && (headless || isMacro)) { - IJ.error("Incorrect Macro call"); - return; - } else if (args == null) { - return; - } else if ((headless || isMacro) && dps.keySet().size() == 0) { - // If no models have been found, do nothing and stop execution - return; - } - // Get the arguments for the model execution - String dirname = args[0]; String format = args[1]; processingFile[0] = args[2]; - processingFile[1] = args[3]; String patchString = args[5]; String debugMode = args[6]; - - dp = dps.get(dirname); - - // If the plugin is running in test mode, get the test image - // that has just been displayed - int currentImagesOpen = WindowManager.getImageTitles().length; - // Check if there has been an image opened, checking the number - // of images open now vs at the begining - boolean imageHasBeenOpened = currentImagesOpen > nOpenImages; - if (testMode && !isMacro && WindowManager.getCurrentImage() != null && imageHasBeenOpened) { - // Set batch mode to false - batch = false; - imp = WindowManager.getCurrentImage(); - // Get basic specifications for the input from the yaml - String tensorForm = dp.params.inputList.get(0).form; - // Minimum size if it is not fixed, 0s if it is - int[] tensorMin = dp.params.inputList.get(0).minimum_size; - // Step if the size is not fixed, 0s if it is - int[] tensorStep = dp.params.inputList.get(0).step; - float[] haloSize = ArrayOperations.findTotalPadding(dp.params.inputList.get(0), dp.params.outputList, dp.params.pyramidalNetwork); - // Get the minimum tile size given by the yaml without batch - int[] min = DijTensor.getWorkingDimValues(tensorForm, tensorMin); - // Get the step given by the yaml without batch - int[] step = DijTensor.getWorkingDimValues(tensorForm, tensorStep); - // Get the halo given by the yaml without batch - float[] haloVals = DijTensor.getWorkingDimValues(tensorForm, haloSize); - // Get the axes given by the yaml without batch - String[] dim = DijTensor.getWorkingDims(tensorForm); - patchString = ArrayOperations.optimalPatch(haloVals, dim, step, min, dp.params.allowPatching); - } else if (testMode && !isMacro) { - // If no image has been displayed there is an error - String err = "No test image has been found in the model folder.\n" - + "There should be an image called: "; - // REtieve the images names - String imageName = dp.params.inputList.get(0).exampleInput; - err += imageName; - // Path to the test image specified in the rdf.yaml in - // the >sample_inputs - String imageName2 = null; - if (dp.params.sampleInputs != null && dp.params.sampleInputs.length != 0) { - imageName2 = dp.params.sampleInputs[0]; - err += " or " + imageName2; - } - IJ.error(err); - run(""); - return; - } - // Check if the patxh size is editable or not - boolean patchEditable = false; - if (!headless && !isMacro && texts[1].isEditable()) - patchEditable = true; - - if (debugMode.equals("debug")) { - log.setLevel(2); - } else if (debugMode.equals("normal")) { - log.setLevel(1); - } else if (debugMode.equals("mute")) { - log.setLevel(0); - } - - if (log.getLevel() >= 1) - log.print("Load model: " + dp.getName() + "(" + dirname + ")"); - - dp.params.framework = format.toLowerCase().contains("pytorch") ? "pytorch" : "tensorflow"; - // Select the needed attachments for the version used - if (dp.params.framework.toLowerCase().contentEquals("pytorch")) { - dp.params.attachments = dp.params.ptAttachments; - } else if (dp.params.framework.toLowerCase().contentEquals("tensorflow")) { - dp.params.attachments = dp.params.tfAttachments; - } - - if (!headless && !isMacro) { - info.setText(""); - info.setCaretPosition(0); - info.append("Loading model. Please wait...\n"); - } - + if (args == null && (headless || isMacro)) { + IJ.error("Incorrect Macro call"); + return; + } else if (args == null) { + return; + } else if ((headless || isMacro) && dps.keySet().size() == 0) { + // If no models have been found, do nothing and stop execution + return; + } + // Get the arguments for the model execution + String dirname = args[0]; String finalFormat = args[1]; processingFile[0] = args[2]; + processingFile[1] = args[3]; String patchString = args[5]; String debugMode = args[6]; + + dp = dps.get(dirname); + + // If the plugin is running in test mode, get the test image + // that has just been displayed + int currentImagesOpen = WindowManager.getImageTitles().length; + // Check if there has been an image opened, checking the number + // of images open now vs at the begining + boolean imageHasBeenOpened = currentImagesOpen > nOpenImages; + if (testMode && !isMacro && WindowManager.getCurrentImage() != null && imageHasBeenOpened) { + // Set batch mode to false + batch = false; + imp = WindowManager.getCurrentImage(); + // Get basic specifications for the input from the yaml + String tensorForm = dp.params.inputList.get(0).form; + // Minimum size if it is not fixed, 0s if it is + int[] tensorMin = dp.params.inputList.get(0).minimum_size; + // Step if the size is not fixed, 0s if it is + int[] tensorStep = dp.params.inputList.get(0).step; + float[] haloSize = ArrayOperations.findTotalPadding(dp.params.inputList.get(0), dp.params.outputList, dp.params.pyramidalNetwork); + // Get the minimum tile size given by the yaml without batch + int[] min = DijTensor.getWorkingDimValues(tensorForm, tensorMin); + // Get the step given by the yaml without batch + int[] step = DijTensor.getWorkingDimValues(tensorForm, tensorStep); + // Get the halo given by the yaml without batch + float[] haloVals = DijTensor.getWorkingDimValues(tensorForm, haloSize); + // Get the axes given by the yaml without batch + String[] dim = DijTensor.getWorkingDims(tensorForm); + patchString = ArrayOperations.optimalPatch(haloVals, dim, step, min, dp.params.allowPatching); + } else if (testMode && !isMacro) { + // If no image has been displayed there is an error + String err = "No test image has been found in the model folder.\n" + + "There should be an image called: "; + // REtieve the images names + String imageName = dp.params.inputList.get(0).exampleInput; + err += imageName; + // Path to the test image specified in the rdf.yaml in + // the >sample_inputs + String imageName2 = null; + if (dp.params.sampleInputs != null && dp.params.sampleInputs.length != 0) { + imageName2 = dp.params.sampleInputs[0]; + err += " or " + imageName2; + } + IJ.error(err); + if (!this.isMacro && !this.headless) + run(""); + return; + } + // Check if the patxh size is editable or not + boolean patchEditable = false; + if (!headless && !isMacro && texts[1].isEditable()) + patchEditable = true; + + if (debugMode.equals("debug")) { + log.setLevel(2); + } else if (debugMode.equals("normal")) { + log.setLevel(1); + } else if (debugMode.equals("mute")) { + log.setLevel(0); + } + + if (log.getLevel() >= 1) + log.print("Load model: " + dp.getName() + "(" + dirname + ")"); + + List engineNamesList = dp.params.weights.getEnginesListWithVersions(); + dp.params.framework = finalFormat; + String format; + if (finalFormat.equals("pytorch")) { + format = "torchscript"; + } else if (finalFormat.equals("tensorflow")) { + format = "tensorflow_saved_model_bundle"; + } else if (finalFormat.equals("onnx")) { + format = "onnx"; + } else { + throw new IllegalArgumentException("Selected 'Format' is not suppported. Only 'Formats' " + System.lineSeparator() + + "supported are Tensorflow, Pytorch and Onnx"); + } + String engineSelected = + engineNamesList.stream().filter(i -> i.startsWith(format)).findFirst().orElse(null); + String source; + String engine; + String version; + try { + engine = dp.params.weights.getWeightsByIdentifier(engineSelected).getWeightsFormat(); + source = dp.params.weights.getWeightsByIdentifier(engineSelected).getSource(); + source = dp.getPath() + File.separator + new File(source).getName(); + version = dp.params.weights.getWeightsByIdentifier(engineSelected).getTrainingVersion(); + } catch (IOException e1) { + IJ.error("The selected model does not contains source file for the selected weights."); + if (!this.isMacro && !this.headless) + run(""); + return; + } + + if (!headless && !isMacro) { + info.setText(""); + info.setCaretPosition(0); + info.append("Loading model. Please wait...\n"); + } - dp.params.firstPreprocessing = null; - dp.params.secondPreprocessing = null; - dp.params.firstPostprocessing = null; - dp.params.secondPostprocessing = null; - - if (!processingFile[0].equals("no preprocessing")) { - // Workaround for ImageJ Macros. - // DeepImageJ always writes the pre and post-processing between brackets, - // however when runnning the plugin for a macro this does not happen when there is only - // one processing file. This workaround adds the brackets - if (isMacro && !processingFile[0].startsWith("[")) - processingFile[0] = "[" + processingFile[0]; - if (isMacro && !processingFile[0].endsWith("]")) - processingFile[0] = processingFile[0] + "]"; - String[] preprocArray = processingFile[0].substring(processingFile[0].indexOf("[") + 1, processingFile[0].lastIndexOf("]")).split(","); - dp.params.firstPreprocessing = dp.getPath() + File.separator + preprocArray[0].trim(); - if (preprocArray.length > 1) { - dp.params.secondPreprocessing = dp.getPath() + File.separator + preprocArray[1].trim(); - } - } - - if (!processingFile[1].equals("no postprocessing")) { - // Workaround for ImageJ Macros. - if (isMacro && !processingFile[1].startsWith("[")) - processingFile[1] = "[" + processingFile[1]; - if (isMacro && !processingFile[1].endsWith("]")) - processingFile[1] = processingFile[1] + "]"; - String[] postprocArray = processingFile[1].substring(processingFile[1].indexOf("[") + 1, processingFile[1].lastIndexOf("]")).split(","); - dp.params.firstPostprocessing = dp.getPath() + File.separator + postprocArray[0].trim(); - if (postprocArray.length > 1) { - dp.params.secondPostprocessing = dp.getPath() + File.separator + postprocArray[1].trim(); - } - } - // TODO generalise for several image inputs - for (DijTensor inp: dp.params.inputList) { - String tensorForm = inp.form; - int[] tensorStep = inp.step; - int[] step = DijTensor.getWorkingDimValues(tensorForm, tensorStep); - String[] dims = DijTensor.getWorkingDims(tensorForm); - - float[] haloSize = ArrayOperations.findTotalPadding(inp, dp.params.outputList, dp.params.pyramidalNetwork); - // haloSize is null if any of the offset definitions of the outputs is not a multiple of 0.5 - if (haloSize == null) { - IJ.error("The rdf.yaml of this model contains an error at 'outputs>shape>offset'.\n" - + "The output offsets defined in the rdf.yaml should be multiples of 0.5.\n" - + " If not, the outputs defined will not have a round number of pixels, which\n" - + "is impossible."); - // Relaunch the plugin - closeAndReopenPlugin(imp); - return; - } - - patch = ArrayOperations.getPatchSize(dims, inp.form, patchString, patchEditable); - if (patch == null) { - IJ.error("Please, introduce the patch size as integers separated by commas.\n" - + "For the axes order 'Y,X,C' with:\n" - + "Y=256, X=256 and C=1, we need to introduce:\n" - + "'256,256,1'\n" - + "Note: the key 'auto' can only be used by the plugin."); - // Relaunch the plugin - closeAndReopenPlugin(imp); - return; - } - - for (int i = 0; i < patch.length; i ++) { - if(haloSize[i] * 2 >= patch[i] && patch[i] != -1) { - String errMsg = "Error: Tiles cannot be smaller or equal than 2 times the halo at any dimension.\n" - + "Please, either choose a bigger tile size or change the halo in the rdf.yaml."; - IJ.error(errMsg); - // Relaunch the plugin - closeAndReopenPlugin(imp); - return; - } - } - for (int i = 0; i < inp.minimum_size.length; i ++) { - if (inp.step[i] != 0 && (patch[i] - inp.minimum_size[i]) % inp.step[i] != 0 && patch[i] != -1 && dp.params.allowPatching) { - int approxTileSize = ((patch[i] - inp.minimum_size[i]) / inp.step[i]) * inp.step[i] + inp.minimum_size[i]; - IJ.error("Tile size at dim: " + tensorForm.split("")[i] + " should be product of:\n " + inp.minimum_size[i] + - " + " + step[i] + "*N, where N can be any integer >= 0.\n" - + "The immediately smaller valid tile size is " + approxTileSize); - // Relaunch the plugin - closeAndReopenPlugin(imp); - return; - } else if (inp.step[i] == 0 && patch[i] != inp.minimum_size[i]) { - IJ.error("Patch size at dim: " + tensorForm.split("")[i] + " should be " + inp.minimum_size[i]); - // Relaunch the plugin - closeAndReopenPlugin(imp); - return; - } - } - } - dp.params.inputList.get(0).recommended_patch = patch; + dp.params.firstPreprocessing = null; + dp.params.secondPreprocessing = null; + dp.params.firstPostprocessing = null; + dp.params.secondPostprocessing = null; + + if (!processingFile[0].equals("no preprocessing")) { + // Workaround for ImageJ Macros. + // DeepImageJ always writes the pre and post-processing between brackets, + // however when runnning the plugin for a macro this does not happen when there is only + // one processing file. This workaround adds the brackets + if (isMacro && !processingFile[0].startsWith("[")) + processingFile[0] = "[" + processingFile[0]; + if (isMacro && !processingFile[0].endsWith("]")) + processingFile[0] = processingFile[0] + "]"; + String[] preprocArray = processingFile[0].substring(processingFile[0].indexOf("[") + 1, processingFile[0].lastIndexOf("]")).split(","); + dp.params.firstPreprocessing = dp.getPath() + File.separator + preprocArray[0].trim(); + if (preprocArray.length > 1) { + dp.params.secondPreprocessing = dp.getPath() + File.separator + preprocArray[1].trim(); + } + } + + if (!processingFile[1].equals("no postprocessing")) { + // Workaround for ImageJ Macros. + if (isMacro && !processingFile[1].startsWith("[")) + processingFile[1] = "[" + processingFile[1]; + if (isMacro && !processingFile[1].endsWith("]")) + processingFile[1] = processingFile[1] + "]"; + String[] postprocArray = processingFile[1].substring(processingFile[1].indexOf("[") + 1, processingFile[1].lastIndexOf("]")).split(","); + dp.params.firstPostprocessing = dp.getPath() + File.separator + postprocArray[0].trim(); + if (postprocArray.length > 1) { + dp.params.secondPostprocessing = dp.getPath() + File.separator + postprocArray[1].trim(); + } + } - ExecutorService service = Executors.newFixedThreadPool(1); - RunnerProgress rp = null; - if (!headless) { - rp = new RunnerProgress(dp, "load", service); - } - else { - System.out.println("[DEBUG] Loading model"); - } + // TODO generalise for several image inputs + for (DijTensor inp: dp.params.inputList) { + String tensorForm = inp.form; + int[] tensorStep = inp.step; + int[] step = DijTensor.getWorkingDimValues(tensorForm, tensorStep); + String[] dims = DijTensor.getWorkingDims(tensorForm); - if (rp!= null && dp.params.framework.contains("tensorflow") && !(new File(dp.getPath() + File.separator + "variables").exists())) { - info.append("Unzipping Tensorflow model. Please wait...\n"); - rp.setUnzipping(true); - } - - boolean iscuda = DeepLearningModel.TensorflowCUDACompatibility(loadInfo, cudaVersion).equals(""); - ModelLoader loadModel = new ModelLoader(dp, rp, loadInfo.contains("GPU"), iscuda, log.getLevel() >= 1, SystemUsage.checkFiji()); + float[] haloSize = ArrayOperations.findTotalPadding(inp, dp.params.outputList, dp.params.pyramidalNetwork); + // haloSize is null if any of the offset definitions of the outputs is not a multiple of 0.5 + if (haloSize == null) { + IJ.error("The rdf.yaml of this model contains an error at 'outputs>shape>offset'.\n" + + "The output offsets defined in the rdf.yaml should be multiples of 0.5.\n" + + " If not, the outputs defined will not have a round number of pixels, which\n" + + "is impossible."); + // Relaunch the plugin + closeAndReopenPlugin(imp); + return; + } + + patch = ArrayOperations.getPatchSize(dims, inp.form, patchString, patchEditable); + if (patch == null) { + IJ.error("Please, introduce the patch size as integers separated by commas.\n" + + "For the axes order 'Y,X,C' with:\n" + + "Y=256, X=256 and C=1, we need to introduce:\n" + + "'256,256,1'\n" + + "Note: the key 'auto' can only be used by the plugin."); + // Relaunch the plugin + closeAndReopenPlugin(imp); + return; + } - Future f1 = service.submit(loadModel); - boolean output = false; - try { - output = f1.get(); - } catch (InterruptedException | ExecutionException e) { - if (rp != null && rp.getUnzipping()) - IJ.error("Unable to unzip model"); - else - IJ.error("Unable to load model"); - e.printStackTrace(); - if (rp != null) - rp.stop(); - } - - - // If the user has pressed stop button, stop execution and return - if (rp != null && rp.isStopped()) { - service.shutdown(); - rp.dispose(); - // Free memory allocated by the plugin - freeIJMemory(dlg, imp); + for (int i = 0; i < patch.length; i ++) { + if(haloSize[i] * 2 >= patch[i] && patch[i] != -1) { + String errMsg = "Error: Tiles cannot be smaller or equal than 2 times the halo at any dimension.\n" + + "Please, either choose a bigger tile size or change the halo in the rdf.yaml."; + IJ.error(errMsg); + // Relaunch the plugin + closeAndReopenPlugin(imp); return; } - - // If the model was not loaded, run again the plugin - if (!output) { - IJ.error("Load model error: " + (dp.getTfModel() == null || dp.getTorchModel() == null)); - service.shutdown(); - if (!isMacro && !headless) - run(""); + } + for (int i = 0; i < inp.minimum_size.length; i ++) { + if (inp.step[i] != 0 && (patch[i] - inp.minimum_size[i]) % inp.step[i] != 0 && patch[i] != -1 && dp.params.allowPatching) { + int approxTileSize = ((patch[i] - inp.minimum_size[i]) / inp.step[i]) * inp.step[i] + inp.minimum_size[i]; + IJ.error("Tile size at dim: " + tensorForm.split("")[i] + " should be product of:\n " + inp.minimum_size[i] + + " + " + step[i] + "*N, where N can be any integer >= 0.\n" + + "The immediately smaller valid tile size is " + approxTileSize); + // Relaunch the plugin + closeAndReopenPlugin(imp); + return; + } else if (inp.step[i] == 0 && patch[i] != inp.minimum_size[i]) { + IJ.error("Patch size at dim: " + tensorForm.split("")[i] + " should be " + inp.minimum_size[i]); + // Relaunch the plugin + closeAndReopenPlugin(imp); return; } - - if (rp != null) - rp.setService(null); + } + } + dp.params.inputList.get(0).recommended_patch = patch; + + ExecutorService service = Executors.newFixedThreadPool(1); + RunnerProgress rp = null; + if (!headless) { + rp = new RunnerProgress(dp, "load", service); + } + else { + System.out.println("[DEBUG] Loading model"); + } + + if (rp!= null && dp.params.framework.contains("tensorflow") && !(new File(dp.getPath() + File.separator + "variables").exists())) { + info.append("Unzipping Tensorflow model. Please wait...\n"); + rp.setUnzipping(true); + } + + boolean iscuda = DeepLearningModel.TensorflowCUDACompatibility(loadInfo, cudaVersion).equals(""); + + EngineInfo engineInfo; + Model model; + try { + engineInfo = EngineInfo.defineCompatibleDLEngine(engine, version, JARS_DIRECTORY); + if (engineInfo == null) + throw new Exception("No compatible engine installed." + System.lineSeparator() + + "Required engine: " + engine + " " + version); + model = Model.createDeepLearningModel(dp.getPath(), source, engineInfo, getClass().getClassLoader()); + } catch (LoadEngineException e1) { + IJ.error("Error loading " + engine + System.lineSeparator() + e1.toString()); + if (!this.isMacro && !this.headless) + run(""); + return; + } catch (Exception e1) { + IJ.error("Error loading " + engine + System.lineSeparator() + e1.toString()); + if (!this.isMacro && !this.headless) + run(""); + return; + } + ModelLoader loadModel = new ModelLoader(dp, model, rp, loadInfo.contains("GPU"), iscuda, log.getLevel() >= 1); + + Future f1 = service.submit(loadModel); + boolean output = false; + try { + output = f1.get(); + } catch (InterruptedException | ExecutionException e) { + if (rp != null && rp.getUnzipping()) + IJ.error("Unable to unzip model"); + else + IJ.error("Unable to load model"); + e.printStackTrace(); + if (rp != null) + rp.stop(); + } + + + // If the user has pressed stop button, stop execution and return + if (rp != null && rp.isStopped()) { + service.shutdown(); + rp.dispose(); + // Free memory allocated by the plugin + freeIJMemory(dlg, imp); + return; + } + + // If the model was not loaded, run again the plugin + if (!output) { + IJ.error("Load model error: " + (dp.getModel() == null)); + service.shutdown(); + if (!isMacro && !headless) + run(""); + return; + } + + if (rp != null) + rp.setService(null); - calculateImage(imp, rp, service); - service.shutdown(); + calculateImage(imp, rp, service); + service.shutdown(); } /** @@ -899,7 +988,11 @@ public void setDLEngine(DeepImageJ dp) { } else if (dp.params.framework.toLowerCase().equals("tensorflow")) { choices[1].removeAll(); choices[1].addItem("Tensorflow"); + } else if (dp.params.framework.toLowerCase().equals("onnx")) { + choices[1].removeAll(); + choices[1].addItem("Onnx"); } + } /** @@ -978,19 +1071,12 @@ public void calculateImage(ImagePlus inp, RunnerProgress rp, ExecutorService ser if (log.getLevel() >= 1) log.print("start runner"); HashMap output = null; - if (dp.params.framework.equals("tensorflow")) { - RunnerTf runner = new RunnerTf(dp, rp, inputsMap, log); - if (rp != null) - rp.setRunner(runner); - Future> f1 = service.submit(runner); - output = f1.get(); - } else { - RunnerPt runner = new RunnerPt(dp, rp, inputsMap, log); - if (rp != null) - rp.setRunner(runner); - Future> f1 = service.submit(runner); - output = f1.get(); - } + + RunnerDL runner = new RunnerDL(dp, rp, inputsMap, log); + if (rp != null) + rp.setRunner(runner); + Future> f1 = service.submit(runner); + output = f1.get(); inp.changes = false; inp.close(); @@ -1079,14 +1165,9 @@ public void freeIJMemory(GenericDialog dlg, ImagePlus imp) { // If it is not headless, there is no GUI, no need to close it if (!headless && !isMacro && !testMode) dlg.dispose(); - // Close the IJ2 services to free all the resources used - if (SystemUsage.checkFiji()) - StartTensorflowService.closeTfService(); - if (dp != null && dp.params.framework.equals("tensorflow") && dp.getTfModel() != null) { - dp.getTfModel().session().close(); - dp.getTfModel().close(); - } else if (dp != null && dp.params.framework.equals("pytorch") && dp.getTorchModel() != null) { - dp.getTorchModel().close(); + if (dp != null && dp.getModel() != null) { + dp.getModel().closeModel(); + dp.setModel(null); } this.dp = null; this.dps = null; @@ -1156,8 +1237,8 @@ public void loadModels(String modelDir) { * no input is provided, loads both engines * */ - public void loadTfAndPytorch() { - loadTfAndPytorch(true, true); + public void loadTfAndPytorch() throws IOException { + findAvailableEngines(); } /* @@ -1165,59 +1246,67 @@ public void loadTfAndPytorch() { * the DJL takes some time. Normally the GUI would not sho until everything is loaded. * In order to show the DeepImageJ Run GUI fast, Tf and Pt are loaded in a separate thread. * - * @param tf - * load Tensorflow library or not - * @param pt - * load Pytorch library or not */ - public void loadTfAndPytorch(boolean tf, boolean pt) { - loadInfo = "ImageJ"; - boolean isFiji = SystemUsage.checkFiji(); + public void findAvailableEngines() throws IOException { + loadInfo = ""; // FOrmat for the date Date now = new Date(); if (!headless && !isMacro) { - info.append("\n\n"); - info.append(" - " + new SimpleDateFormat("HH:mm:ss").format(now) + " -- LOADING TENSORFLOW JAVA (might take some time)\n"); + info.append(System.lineSeparator()); + info.append(" - " + new SimpleDateFormat("HH:mm:ss").format(now) + + " -- CHECKING THE REQUIRED ENGINES ARE INSTALLED"); + info.append(System.lineSeparator()); } - // First load Tensorflow - if (isFiji && (!(headless || isMacro) || tf)) { - loadInfo = StartTensorflowService.loadTfLibrary(); - } else if (!(headless || isMacro) || pt) { - // In order to get Pytorch to work we have to set - // the IJ ClassLoader as the ContextClassLoader - Thread.currentThread().setContextClassLoader(IJ.getClassLoader()); + + EngineManagement engineManager = EngineManagement.createManager(); + Map> consumers = + new LinkedHashMap>(); + Thread checkAndInstallMissingEngines = new Thread(() -> { + consumers.putAll(engineManager.getBasicEnginesProgress()); + engineManager.basicEngineInstallation(); + }); + extraThreads.add(checkAndInstallMissingEngines); + System.out.println("[DEBUG] Checking and installing missing engines"); + checkAndInstallMissingEngines.start(); + + String backup = null; + if (!headless && !isMacro) { + backup = info.getText(); } - // If the version allows GPU, find if there is CUDA - if (loadInfo.contains("GPU")) - cudaVersion = SystemUsage.getCUDAEnvVariables(); + EngineInstaller installerInfo = new EngineInstaller(); + while (!engineManager.isManagementDone()) { + try {Thread.sleep(300);} catch (InterruptedException e) {} + if ((!headless && !isMacro) && consumers.keySet().size() != 0) { + String progress = installerInfo.basicEnginesInstallationProgress(consumers); + info.setText(backup + System.lineSeparator() + progress); + info.setCaretPosition(info.getText().length()); + } + } + + installedEngines = InstalledEngines.buildEnginesFinder().getDownloadedForOS(); + List engineNames = + installedEngines.stream() + .map(v -> v.getEngine() + "-" + + v.getPythonVersion() + " (GPU: " + v.getGPU() + ") " + + getCudaVersionsCompatible(v.getEngine(), v.getPythonVersion(), v.getGPU())) + .collect(Collectors.toList()); - if (loadInfo.equals("")) { - loadInfo += "No Tensorflow library found.\n"; - loadInfo += "Please install a new Tensorflow version.\n"; - } else if (loadInfo.equals("ImageJ") && (!headless || tf)) { - loadInfo = "Currently using TensorFlow "; - loadInfo += DeepLearningModel.getTFVersion(false); - if (!loadInfo.contains("GPU")) - loadInfo += "_CPU"; - loadInfo += ".\n"; - loadInfo += "To change the version, consult the DeepImageJ Wiki.\n"; + if (engineNames.size() == 0) { + loadInfo += "No Deep Learning frameworks installed, please install." + System.lineSeparator(); } else { - loadInfo += ".\n"; - loadInfo += "To change the TF version go to Edit>Options>Tensorflow.\n"; - } - // Then find Pytorch the Pytorch version - if (!headless && !isMacro) - info.append(" - " + new SimpleDateFormat("HH:mm:ss").format(now) + " -- LOADING DJL PYTORCH\n"); - String ptVersion = null; - if (!(headless || isMacro) || pt) - ptVersion = DeepLearningModel.getPytorchVersion(); - loadInfo += "\n"; - loadInfo += "Currently using Pytorch " + ptVersion + ".\n"; - loadInfo += "Supported by Deep Java Library " + ptVersion + ".\n"; - - if (!headless && !isMacro) + loadInfo += "Available Deep Learning frameworks:" + System.lineSeparator(); + for (String names : engineNames) + loadInfo += " -" + names + System.lineSeparator(); + } + loadInfo += System.lineSeparator(); + System.out.println("[DEBUG] " + loadInfo); + + // If the version allows GPU, find if there is CUDA + if (!headless && !isMacro) { info.append(" - " + new SimpleDateFormat("HH:mm:ss").format(now) + " -- Looking for installed CUDA distributions"); - getCUDAInfo(loadInfo, ptVersion, cudaVersion); + cudaVersion = SystemUsage.getCUDAEnvVariables(); + loadInfo += "Installed CUDA versions: " + cudaVersion + System.lineSeparator(); + } loadInfo += "Models' path: " + DeepImageJ.cleanPathStr(path) + "\n"; loadInfo += "\n"; @@ -1229,6 +1318,25 @@ public void loadTfAndPytorch(boolean tf, boolean pt) { } } + /* + * Find out which CUDA version it is being used and if its use is viable toguether with + * Tensorflow + */ + public String getCudaVersionsCompatible(String engine, String version, boolean gpu) { + String cudas = null; + if (engine.equals(EngineInfo.getPytorchKey()) && SystemUsage.MAP_PYTORCH_CUDA.get(version) != null) { + cudas = SystemUsage.MAP_PYTORCH_CUDA.get(version).toString(); + } else if (engine.equals(EngineInfo.getTensorflowKey()) && SystemUsage.MAP_TF_CUDA.get(version) != null) { + cudas = SystemUsage.MAP_TF_CUDA.get(version).toString(); + } else if (engine.equals(EngineInfo.getOnnxKey()) && SystemUsage.MAP_ONNX_CUDA.get(version) != null) { + cudas = SystemUsage.MAP_ONNX_CUDA.get(version).toString(); + } + if (cudas == null){ + cudas = ""; + } + return cudas; + } + /* * Find out which CUDA version it is being used and if its use is viable toguether with * Tensorflow @@ -1323,7 +1431,13 @@ private void setMissingYamlText() { */ public void run() { loadModels(); - loadTfAndPytorch(); + try { + loadTfAndPytorch(); + } catch (IOException ex) { + IJ.error("Unable to find an engines directory. Please create" + + System.lineSeparator() + "a folder called" + + " engines inside the ImageJ/Fiji folder."); + } } @Override @@ -1334,12 +1448,14 @@ public void actionPerformed(ActionEvent e) { DeepImageJ dp = dps.get(dirname); // Path to the test image specified in the rdf.yaml in // the >config>deepimagej>test_information part - String imageName = dp.getPath() + dp.params.inputList.get(0).exampleInput; + String imageName = ""; + if (dp.params.inputList != null && dp.params.inputList.get(0).exampleInput != null) + imageName = dp.getPath() + new File(dp.params.inputList.get(0).exampleInput).getName(); // Path to the test image specified in the rdf.yaml in // the >sample_inputs String imageName2 = null; if (dp.params.sampleInputs != null && dp.params.sampleInputs.length != 0) - imageName2 = dp.getPath() + dp.params.sampleInputs[0]; + imageName2 = dp.getPath() + new File(dp.params.sampleInputs[0]).getName(); ImagePlus imp = null; // Do not try to read npy files boolean notNpy = !(imageName.endsWith(".npy") || imageName.endsWith(".npx") || imageName.endsWith(".np")); @@ -1355,7 +1471,7 @@ public void actionPerformed(ActionEvent e) { // Do nothing } } - notNpy = !(imageName2.endsWith(".npy") || imageName2.endsWith(".npx") || imageName2.endsWith(".np")); + notNpy = imageName2 != null && !(imageName2.endsWith(".npy") || imageName2.endsWith(".npx") || imageName2.endsWith(".np")); if (!openTest && notNpy && imageName2 != null && new File(imageName2).isFile()) { try{ imp = IJ.openImage(imageName2); @@ -1377,5 +1493,16 @@ public void actionPerformed(ActionEvent e) { okay.getActionListeners()[0].actionPerformed(ee); testMode = true; } + + /** + * Close all the threads that have been opened during the execution + */ + private void closeAllThreads() { + extraThreads.stream().forEach(t -> { + if (t != null) + t.interrupt(); + t = null; + }); + } } \ No newline at end of file diff --git a/src/main/java/deepimagej/BuildDialog.java b/src/main/java/deepimagej/BuildDialog.java deleted file mode 100755 index 22e1b752..00000000 --- a/src/main/java/deepimagej/BuildDialog.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej; - -import java.awt.BorderLayout; -import java.awt.CardLayout; -import java.awt.Dimension; -import java.awt.GridLayout; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.awt.event.WindowAdapter; -import java.awt.event.WindowEvent; -import java.io.File; - -import javax.swing.JButton; -import javax.swing.JDialog; -import javax.swing.JFrame; -import javax.swing.JPanel; - -import deepimagej.components.TitleHTMLPane; -import deepimagej.stamp.InputDimensionStamp; -import deepimagej.stamp.JavaPostprocessingStamp; -import deepimagej.stamp.JavaPreprocessingStamp; -import deepimagej.stamp.LoadPytorchStamp; -import deepimagej.stamp.InformationStamp; -import deepimagej.stamp.LoadTFStamp; -import deepimagej.stamp.OutputDimensionStamp; -import deepimagej.stamp.PtSaveStamp; -import deepimagej.stamp.SaveOutputFilesStamp; -import deepimagej.stamp.TfSaveStamp; -import deepimagej.stamp.SelectPyramidalStamp; -import deepimagej.stamp.TensorPytorchTmpStamp; -import deepimagej.stamp.TensorStamp; -import deepimagej.stamp.TestStamp; -import deepimagej.stamp.WelcomeStamp; -import deepimagej.tools.Log; -import deepimagej.tools.WebBrowser; -import ij.IJ; -import ij.gui.GUI; - -public class BuildDialog extends JDialog implements ActionListener { - - public JButton bnNext = new JButton("Next"); - public JButton bnBack = new JButton("Back"); - public JButton bnClose = new JButton("Cancel"); - public JButton bnHelp = new JButton("Help"); - private JPanel pnCards = new JPanel(new CardLayout()); - - private WelcomeStamp welcome = null; - private LoadTFStamp loaderTf = null; - private LoadPytorchStamp loaderPt = null; - private SelectPyramidalStamp selectPyramid = null; - private InputDimensionStamp dim3 = null; - private OutputDimensionStamp outputDim = null; - private TensorStamp tensorTf = null; - private TensorPytorchTmpStamp tensorPt = null; - private InformationStamp info = null; - private JavaPreprocessingStamp javaPreproc = null; - private JavaPostprocessingStamp javaPostproc = null; - private TestStamp test2 = null; - private SaveOutputFilesStamp outputSelection = null; - private TfSaveStamp tfSave = null; - private PtSaveStamp ptSave = null; - private DeepImageJ dp = null; - private int card = 1; - private String GPU_TF = "CPU"; - private String GPU_PT = "CPU"; - private boolean Fiji = false; - - public BuildDialog() { - super(new JFrame(), "Build Bundled Model [" + Constants.version + "]"); - } - - public void showDialog() { - - welcome = new WelcomeStamp(this); - loaderTf = new LoadTFStamp(this); - loaderPt = new LoadPytorchStamp(this); - selectPyramid = new SelectPyramidalStamp(this); - dim3 = new InputDimensionStamp(this); - tensorTf = new TensorStamp(this); - tensorPt = new TensorPytorchTmpStamp(this); - info = new InformationStamp(this); - outputDim = new OutputDimensionStamp(this); - javaPreproc = new JavaPreprocessingStamp(this); - javaPostproc = new JavaPostprocessingStamp(this); - test2 = new TestStamp(this); - outputSelection = new SaveOutputFilesStamp(this); - tfSave = new TfSaveStamp(this); - ptSave = new PtSaveStamp(this); - - JPanel pnButtons = new JPanel(new GridLayout(1, 4)); - pnButtons.add(bnClose); - pnButtons.add(bnHelp); - pnButtons.add(bnBack); - pnButtons.add(bnNext); - - pnCards.add(welcome.getPanel(), "1"); - pnCards.add(loaderTf.getPanel(), "2"); - pnCards.add(loaderPt.getPanel(), "20"); - pnCards.add(selectPyramid.getPanel(), "3"); - pnCards.add(tensorTf.getPanel(), "4"); - pnCards.add(tensorPt.getPanel(), "40"); - pnCards.add(dim3.getPanel(), "5"); - pnCards.add(outputDim.getPanel(), "6"); - pnCards.add(info.getPanel(), "7"); - pnCards.add(javaPreproc.getPanel(), "8"); - pnCards.add(javaPostproc.getPanel(), "9"); - pnCards.add(test2.getPanel(), "10"); - pnCards.add(outputSelection.getPanel(), "11"); - pnCards.add(tfSave.getPanel(), "12"); - pnCards.add(ptSave.getPanel(), "120"); - - setLayout(new BorderLayout()); - add(new TitleHTMLPane().getPane(), BorderLayout.NORTH); - add(pnCards, BorderLayout.CENTER); - add(pnButtons, BorderLayout.SOUTH); - - bnNext.addActionListener(this); - bnBack.addActionListener(this); - bnClose.addActionListener(this); - bnHelp.addActionListener(this); - - setResizable(true); - pack(); - setPreferredSize(new Dimension(Constants.width, 300)); - GUI.center(this); - setVisible(true); - bnBack.setEnabled(false); - - // Close model when the plugin is closed - this.addWindowListener(new WindowAdapter() - { - public void windowClosed(WindowEvent e) { - // Release every component of each stamp - pnCards.removeAll(); - removeAll(); - if (dp == null) - return; - if (getDeepPlugin().getTfModel() != null) { - getDeepPlugin().getTfModel().session().close(); - getDeepPlugin().getTfModel().close(); - } else if (getDeepPlugin().getTorchModel() != null) { - getDeepPlugin().getTorchModel().close(); - } - } - public void windowClosing(WindowEvent e) { - // Release every component of each stamp - pnCards.removeAll(); - removeAll(); - if (dp == null) - return; - if (getDeepPlugin().getTfModel() != null) { - getDeepPlugin().getTfModel().session().close(); - getDeepPlugin().getTfModel().close(); - } else if (getDeepPlugin().getTorchModel() != null) { - getDeepPlugin().getTorchModel().close(); - } - } - }); - - } - - private void setCard(String name) { - CardLayout cl = (CardLayout) (pnCards.getLayout()); - if (name.equals("2") && dp.params.framework.equals("pytorch")) - name = "20"; - else if (name.equals("4") && dp.params.framework.equals("pytorch")) - name = "40"; - else if (name.equals("12") && dp.params.framework.equals("pytorch")) - name = "120"; - cl.show(pnCards, name); - } - - @Override - public void actionPerformed(ActionEvent e) { - - bnNext.setText("Next"); - bnNext.setEnabled(true); - if (e.getSource() == bnNext) { - switch (card) { - case 1: - if (welcome.finish()) { - card = 2; - String path = welcome.getModelDir(); - String name = welcome.getModelName(); - if (path != null) { - dp = new DeepImageJ(path, name, true); - if (dp.getTfModel() != null) - dp.getTfModel().close(); - else if (dp.getTorchModel() != null) - dp.getTorchModel().close(); - if (dp != null) { - dp.params.path2Model = path + File.separator + name + File.separator; - if (dp.getValid() && dp.params.framework.contains("tensorflow")) { - loaderTf.init(); - } else if (dp.getValid() && dp.params.framework.contains("pytorch")) { - loaderPt.init(); - } else if (!dp.getValid()) { - IJ.error("Please select a correct model"); - card = 1; - } - } - } - } - break; - case 2: - if (dp.params.framework.contains("tensorflow")) { - card = loaderTf.finish() ? card+1 : card; - } else if (dp.params.framework.contains("pytorch")) { - card = loaderPt.finish() ? card+1 : card; - } - break; - case 3: - card = selectPyramid.finish() ? card+1 : card; - break; - case 4: - if (dp.params.framework.contains("tensorflow")) { - card = tensorTf.finish() ? card+1 : card; - } else if (dp.params.framework.contains("pytorch")) { - card = tensorPt.finish() ? card+1 : card; - } - break; - case 5: - card = dim3.finish() ? card+1 : card; - break; - case 6: - card = outputDim.finish() ? card+1 : card; - break; - case 7: - card = info.finish() ? card+1 : card; - break; - case 8: - card = javaPreproc.finish() ? card+1 : card; - break; - case 9: - card = javaPostproc.finish() ? card+1 : card; - break; - case 11: - card = outputSelection.finish() ? card+1 : card; - break; - case 12: - dispose(); - default: - card = Math.min(12, card + 1); - } - } - if (e.getSource() == bnBack) { - card = Math.max(1, card - 1); - } - if (e.getSource() == bnClose) { - dispose(); - } - if (e.getSource() == bnHelp) { - WebBrowser.openDeepImageJ(); - } - - setCard("" + card); - bnBack.setEnabled(card > 1); - if (card == 4 && dp.params.framework.contains("tensorflow")) - tensorTf.init(); - else if (card == 4 && dp.params.framework.contains("pytorch")) - tensorPt.init(); - else if (card == 5) - dim3.init(); - else if (card == 6) - outputDim.init(); - else if (card == 7) - info.init(); - else if (card == 8) - javaPreproc.init(); - else if (card == 9) - javaPostproc.init(); - else if (card == 10) - test2.init(); - else if (card == 11) - outputSelection.init(); - else if (card == 12) - setEnabledBackNext(true); - - bnNext.setText(card == 12 ? "Finish" : "Next"); - } - - public void setEnabledBackNext(boolean b) { - bnBack.setEnabled(b); - bnNext.setEnabled(b); - } - - public void setEnabledNext(boolean b) { - bnNext.setEnabled(b); - } - - public void setEnabledBack(boolean b) { - bnBack.setEnabled(b); - } - - public void endsTest() { - bnBack.setEnabled(true); - bnNext.setEnabled(true); - bnNext.setText("Next"); - } - - public DeepImageJ getDeepPlugin() { - return dp; - } - - public String getGPUTf() { - return GPU_TF; - } - - public void setGPUTf(String info) { - GPU_TF = info; - } - - public String getGPUPt() { - return GPU_PT; - } - - public void setGPUPt(String info) { - GPU_PT = info; - } - - public boolean getFiji() { - return Fiji; - } - - public void setFiji(boolean fiji) { - Fiji = fiji; - } -} diff --git a/src/main/java/deepimagej/Constants.java b/src/main/java/deepimagej/Constants.java index 6b265963..f9881321 100755 --- a/src/main/java/deepimagej/Constants.java +++ b/src/main/java/deepimagej/Constants.java @@ -47,7 +47,7 @@ public class Constants { public static String url = "https://deepimagej.github.io/deepimagej/"; - public static String version = "2.1.16"; + public static String version = "3.0.1"; public static int width = 120; public static String name = "deepImageJ"; diff --git a/src/main/java/deepimagej/DeepImageJ.java b/src/main/java/deepimagej/DeepImageJ.java index 10cb6e26..5b5df05f 100755 --- a/src/main/java/deepimagej/DeepImageJ.java +++ b/src/main/java/deepimagej/DeepImageJ.java @@ -49,23 +49,20 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; +import java.nio.file.FileSystems; +import java.nio.file.Path; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; import java.text.SimpleDateFormat; import java.util.Date; import java.util.HashMap; -import org.tensorflow.SavedModelBundle; -import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDList; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ModelZoo; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; import deepimagej.tools.DijTensor; import deepimagej.tools.FileTools; import ij.IJ; import ij.gui.GenericDialog; +import io.bioimage.modelrunner.model.Model; public class DeepImageJ { @@ -76,8 +73,7 @@ public class DeepImageJ { // Specifies if the yaml is present in the model folder public boolean presentYaml = true; private boolean developer = true; - private SavedModelBundle tfModel = null; - private ZooModeltorchModel = null; + private Model model = null; public String ptName = "pytorch_script.pt"; public String ptName2 = "weights-torchscript.pt"; public String tfName = "tensorflow_saved_model_bundle.zip"; @@ -99,7 +95,7 @@ public DeepImageJ(String pathModel, String dirname, boolean dev) { this.params.path2Model = this.path; this.valid = check(p); } catch (Exception ex) { - IJ.log("Unable to read the rdf.yaml specifications file in following fodler.\n" + IJ.log("Unable to read the rdf.yaml specifications file in following folder.\n" + "Please review that the compulsory fields are not missing.\n" + " -" + path); } @@ -128,20 +124,12 @@ public String getName() { return name.replace("\"", ""); } - public ZooModel getTorchModel() { - return torchModel; + public Model getModel() { + return model; } - public void setTorchModel(ZooModel model) { - this.torchModel = model; - } - - public SavedModelBundle getTfModel() { - return tfModel; - } - - public void setTfModel(SavedModelBundle model) { - this.tfModel = model; + public void setModel(Model model) { + this.model = model; } public boolean getValid() { @@ -195,59 +183,14 @@ static public HashMap list(String pathModels, boolean isDeve } - public boolean loadTfModel(boolean archi) { - - double chrono = System.nanoTime(); - SavedModelBundle model; - try { - model = SavedModelBundle.load(path, DeepLearningModel.returnStringTag(params.tag)); - setTfModel(model); - } - catch (Exception e) { - IJ.log("Exception in loading model " + dirname); - IJ.log(e.toString()); - IJ.log(e.getMessage()); - return false; - } - chrono = (System.nanoTime() - chrono) / 1000000.0; - return true; - } - - - public boolean loadPtModel(String path, boolean isFiji) { + public boolean loadModel() { try { + /** TODO + * if (!isFiji) Thread.currentThread().setContextClassLoader(IJ.getClassLoader()); - URL url = new File(new File(path).getParent()).toURI().toURL(); - - String modelName = new File(path).getName(); - modelName = modelName.substring(0, modelName.indexOf(".pt")); - Criteria criteria = Criteria.builder() - .setTypes(NDList.class, NDList.class) - // only search the model in local directory - // "ai.djl.localmodelzoo:{name of the model}" - .optModelUrls(url.toString()) // search models in specified path - //.optArtifactId("ai.djl.localmodelzoo:resnet_18") // defines which model to load - .optModelName(modelName) - .optProgress(new ProgressBar()).build(); - - ZooModel model = ModelZoo.loadModel(criteria); - this.setTorchModel(model); - - } catch (MalformedURLException e) { - e.printStackTrace(); - return false; - } catch (ModelNotFoundException e) { - e.printStackTrace(); - return false; - } catch (MalformedModelException e) { - e.printStackTrace(); - return false; - } catch (IOException e) { - IJ.log("Model not found in the path provided:"); - IJ.log(path); - e.printStackTrace(); - return false; + */ + model.loadModel(); } catch (UnsatisfiedLinkError e) { e.printStackTrace(); IJ.log("DeepImageJ could not load the Pytorch model."); @@ -408,6 +351,7 @@ public boolean check(String path) { } boolean validTf = false; boolean validPt = false; + boolean validOnnx = false; File modelFile = new File(path + "saved_model.pb"); File variableFile = new File(path + "variables"); @@ -422,11 +366,23 @@ public boolean check(String path) { validPt = true; this.params.framework = "pytorch"; } + + // For onnx models, check if the folder contains weights with the name something.onnx + PathMatcher matcher = FileSystems.getDefault().getPathMatcher("glob:**.onnx"); + File folder = new File(path); + for(File f: folder.listFiles()) { + if (matcher.matches(f.toPath())){ + validOnnx = true; + break; + } + } + + if (validTf && validPt) this.params.framework = "tensorflow/pytorch"; - if (!validTf && !validPt) { + if (!validTf && !validPt && !validOnnx) { // Find zipped biozoo model try { validTf = findZippedBiozooModel(dir); @@ -434,8 +390,12 @@ public boolean check(String path) { validTf = false; } } + + if (validOnnx && !validPt && !validTf){ + this.params.framework = "onnx"; + } - return validTf || validPt; + return validTf || validPt || validOnnx; } /* diff --git a/src/main/java/deepimagej/DeepLearningModel.java b/src/main/java/deepimagej/DeepLearningModel.java index 8e20f549..0999b8fb 100755 --- a/src/main/java/deepimagej/DeepLearningModel.java +++ b/src/main/java/deepimagej/DeepLearningModel.java @@ -44,229 +44,12 @@ package deepimagej; -import java.io.File; -import java.net.JarURLConnection; -import java.net.URL; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import org.tensorflow.SavedModelBundle; -import org.tensorflow.TensorFlowException; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; -import org.tensorflow.framework.TensorShapeProto; -import org.tensorflow.framework.TensorShapeProto.Dim; - -import com.google.protobuf.InvalidProtocolBufferException; import deepimagej.tools.DijTensor; import deepimagej.tools.Index; -import deepimagej.tools.Log; -import deepimagej.tools.StartTensorflowService; -import ij.IJ; -import net.imagej.tensorflow.TensorFlowVersion; public class DeepLearningModel { - // Same as the tag used in export_saved_model in the Python code. - private static final String[] MODEL_TAGS = {"serve", "inference", "train", "eval", "gpu", "tpu"}; - private static final String DEFAULT_TAG = "serve"; - - - private static final String[] TF_MODEL_TAGS = {"tf.saved_model.tag_constants.SERVING", - "tf.saved_model.tag_constants.INFERENCE", - "tf.saved_model.tag_constants.TRAINING", - "tf.saved_model.tag_constants.EVAL", - "tf.saved_model.tag_constants.GPU", - "tf.saved_model.tag_constants.TPU"}; - - - private static final String[] SIGNATURE_CONSTANTS = {"serving_default", - "inputs", - "tensorflow/serving/classify", - "classes", - "scores", - "inputs", - "tensorflow/serving/predict", - "outputs", - "inputs", - "tensorflow/serving/regress", - "outputs", - "train", - "eval", - "tensorflow/supervised/training", - "tensorflow/supervised/eval"}; - - private static final String[] TF_SIGNATURE_CONSTANTS = {"tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.CLASSIFY_INPUTS", - "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", - "tf.saved_model.signature_constants.PREDICT_INPUTS", - "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", - "tf.saved_model.signature_constants.PREDICT_OUTPUTS", - "tf.saved_model.signature_constants.REGRESS_INPUTS", - "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", - "tf.saved_model.signature_constants.REGRESS_OUTPUTS", - "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", - "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME"}; - - // TODO remove this or the next method - public static SavedModelBundle loadTf(String path, String tag, Log log) { - log.print("load model from " + path); - - SavedModelBundle model = null; - try { - Runtime instance = Runtime.getRuntime(); - double a = instance.freeMemory() / (1024*1024.0); - model = SavedModelBundle.load(path, tag); - double b = instance.freeMemory() / (1024*1024.0); - System.out.println(b-a); - } - catch (Exception e) { - log.print("Exception in loading model " + path); - log.print(e.toString()); - log.print(e.getMessage()); - return null; - } - log.print("Loaded"); - return model; - } - - public static SavedModelBundle loadTfModel(String source, String modelTag) { - // Load the model with its correspondent tag - SavedModelBundle model; - try { - model = SavedModelBundle.load(source, modelTag); - } - catch (TensorFlowException e) { - System.out.println("The tag was incorrect"); - model = null; - } - return model; - } - - public static Object[] findTfTag(String source) { - // Obtain the model_tag needed to load the model. If none works, - // 'null' is returned - Object[] info = checkTfTags(source, DEFAULT_TAG); - return info; - } - - public static Object[] checkTfTags(String source, String tag) { - SavedModelBundle model = null; - Set sigKeys; - Object[] info = new Object[3]; - try { - model = SavedModelBundle.load(source, tag); - sigKeys = metaGraphsSet(model); - } - catch (TensorFlowException e) { - // If the tag does not work, try with the following existing tag - int tag_ind = Index.indexOf(MODEL_TAGS, tag); - if (tag_ind < MODEL_TAGS.length - 1) { - Object[] info2 = checkTfTags(source, MODEL_TAGS[tag_ind + 1]); - tag = (String) info2[0]; - sigKeys = (Set) info2[1]; - } - else { - // tag = null, the user will need to introduce it - tag = null; - sigKeys = null; - } - } - info[0] = tag; - info[1] = sigKeys; - info[2] = model; - return info; - } - - public static Set metaGraphsSet(SavedModelBundle model) { - byte[] byteGraph = model.metaGraphDef(); - // Obtain a mapping between the possible keys and their signature definitions - Map sig = null; - try { - sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefMap(); - } - catch (InvalidProtocolBufferException e) { - System.out.println("The model is not a correct SavedModel model"); - } - Set modelKeys = sig.keySet(); - return modelKeys; - } - - public static SignatureDef getSignatureFromGraph(SavedModelBundle model, String graph) { - byte[] byteGraph = model.metaGraphDef(); - SignatureDef sig = null; - try { - sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefOrThrow(graph); - } - catch (InvalidProtocolBufferException e) { - System.out.println("Invalid graph"); - } - return sig; - } - - public static int[] modelTfExitDimensions(SignatureDef sig, String entryName) { - // This method returns the dimensions of the tensor defined by - // the saved model. The method retrieves the tensor info and - // converts it into an array of integers. - TensorInfo entryInfo = sig.getOutputsOrThrow(entryName); - TensorShapeProto entryShape = entryInfo.getTensorShape(); - List listDim = entryShape.getDimList(); - int rank = listDim.size(); - int[] inputTensorSize = new int[rank]; - - for (int i = 0; i < rank; i++) { - inputTensorSize[i] = (int) listDim.get(i).getSize(); - } - return inputTensorSize; - } - - public static int[] modelTfEntryDimensions(SignatureDef sig, String entryName) { - // This method returns the dimensions of the tensor defined by - // the saved model. The method retrieves the tensor info and - // converts it into an array of integers. - TensorInfo entryInfo = sig.getInputsOrThrow(entryName); - TensorShapeProto entryShape = entryInfo.getTensorShape(); - List listDim = entryShape.getDimList(); - int rank = listDim.size(); - int[] inputTensorSize = new int[rank]; - - for (int i = 0; i < rank; i++) { - inputTensorSize[i] = (int) listDim.get(i).getSize(); - } - - return inputTensorSize; - } - - public static String[] returnTfOutputs(SignatureDef sig) { - - // Extract names from the model signature. - // The strings "input", "probabilities" and "patches" are meant to be - // in sync with the model exporter (export_saved_model()) in Python. - Map out = sig.getOutputsMap(); - Set outputKeys = out.keySet(); - String[] keysArray = outputKeys.toArray(new String[outputKeys.size()]); - return keysArray; - } - - public static String[] returnTfInputs(SignatureDef sig) { - - // Extract names from the model signature. - // The strings "input", "probabilities" and "patches" are meant to be - // in sync with the model exporter (export_saved_model()) in Python. - Map inp = sig.getInputsMap(); - Set inputKeys = inp.keySet(); - String[] keysArray = inputKeys.toArray(new String[inputKeys.size()]); - return keysArray; - } - public static int nChannelsOrSlices(DijTensor tensor, String channelsOrSlices) { // Find the number of channels or slices in the corresponding tensor String letter = ""; @@ -330,63 +113,6 @@ public static String nBatch(int[] dims, String inputForm) { return inBatch; } - public static String returnTfTag(String tag) { - String tfTag; - int tagInd = Index.indexOf(MODEL_TAGS, tag); - if (tagInd == -1) { - tfTag = tag; - } else { - tfTag = TF_MODEL_TAGS[tagInd]; - } - return tfTag; - } - - public static String returnStringTag(String tfTag) { - String tag; - int tagInd = Index.indexOf(TF_MODEL_TAGS, tfTag); - if (tagInd == -1) { - tag = tfTag; - } else { - tag = MODEL_TAGS[tagInd]; - } - return tag; - } - - public static Set returnTfSig(Set sig) { - Set tfSig = new HashSet<>(); - for (int i = 0; i < TF_SIGNATURE_CONSTANTS.length; i ++) { - if (sig.contains(SIGNATURE_CONSTANTS[i]) == true) { - tfSig.add(TF_SIGNATURE_CONSTANTS[i]); - } - } - if (tfSig.size() != sig.size()) { - tfSig = sig; - } - return tfSig; - } - - public static String returnStringSig(String tfSig) { - String sig; - int sigInd = Index.indexOf(TF_SIGNATURE_CONSTANTS, tfSig); - if (sigInd == -1) { - sig = tfSig; - } else { - sig = SIGNATURE_CONSTANTS[sigInd]; - } - return sig; - } - - public static String returnTfSig(String sig) { - String tfSig; - int tfSigInd = Index.indexOf(SIGNATURE_CONSTANTS, sig); - if (tfSigInd == -1) { - tfSig = sig; - } else { - tfSig = TF_SIGNATURE_CONSTANTS[tfSigInd]; - } - return tfSig; - } - // TODO group Tf and Pytorch methods regarding versions /* * Find if the CUDA and Tf versions are compatible @@ -444,233 +170,5 @@ public static String TensorflowCUDACompatibility(String tfVersion, String CUDAVe } return errMessage; } - - /* - * Get the Pytorch version number from the jar file - * The corresponding JAR is pytorch-native-auto-X.Y.Z.jar, - * where X.Y.Z is the version number - */ - public static String getPytorchVersion() { - String ptJni = ""; - - // TODO this only works for 1.7.0 - /*try { - URL resource = NativeHelper.class.getResource("NativeHelper.class"); - JarURLConnection connection = null; - connection = (JarURLConnection) resource.openConnection(); - ptJni = connection.getJarFileURL().getFile(); - } catch (Exception e) { - } - */ - ptJni = getLibPytorchJar(); - if (!ptJni.contains("jar")) - return ptJni; - - String ptVersion = getPytorchVersionFromJar(ptJni); - return ptVersion; - } - - /* - * Finds the directory where the Pytorch jar is - */ - public static String getLibPytorchJar() { - - // Search in the plugins folder - String ijDirectory = IJ.getDirectory("imagej") + File.separator; - // TODO remove - //ijDirectory = "C:\\Users\\Carlos(tfg)\\Desktop\\Fiji.app"; - - String pluginsDirectory = ijDirectory + File.separator + "plugins" + File.separator; - String pluginsJar = findPytorchJar(pluginsDirectory); - - // Search in the jars folder - String jarDirectory = ijDirectory + File.separator + "jars" + File.separator; - String jarsJar = findPytorchJar(jarDirectory); - - // Check that there is only one jar file present in both folders - if (jarsJar.equals(pluginsJar) && jarsJar.equals("")) { - return "-No Pytorch version found-"; - } else if (jarsJar.toLowerCase().contains("more than 1 version") || pluginsJar.toLowerCase().contains("more than 1 version")) { - return "-More than one Pytorch version present-"; - } else if (jarsJar.toLowerCase().contains("tensorflow") && jarsJar.toLowerCase().contains("tensorflow") && !jarsJar.equals(pluginsJar)) { - return "-The plugins and jars directories contains a different version of Pytorch each-"; - } - - // Find which of them is actually the TF jni jar - String pytorchJni = pluginsJar; - if (pytorchJni.equals("") == true) { - pytorchJni = jarsJar; - } - return pytorchJni; - } - - /* - * Finds the file corresponding to the tf jar - */ - public static String findPytorchJar(String folderDir) { - // Find the file libtensorflow_jni.jar - - // Name of the TF jni without the version - String jarName = "pytorch-native-auto"; - // Auxiliary variable to make sure we only have one TF jni - int nJars = 0; - String ptJar = ""; - - File folder = new File(folderDir); - File[] listOfFiles = folder.listFiles(); - - if (listOfFiles == null) - return ""; - - for (File file : listOfFiles) { - if (file.isFile() == true) { - String fileName = file.getAbsolutePath(); - if (fileName.indexOf(jarName) != -1) { - nJars ++; - ptJar = fileName; - } - } - } - - if (nJars == 0) { - ptJar = ""; - } else if (nJars >1) { - ptJar = "more than 1 version"; - } - - return ptJar; - } - - /* - * Get the version number from the jar file - */ - public static String getPytorchVersionFromJar(String jar) { - // Name of the TF jni without the version - jar = jar.toLowerCase(); - String flag = "pytorch-native-auto-"; - String jarExt = ".jar"; - String tfVersion = jar.substring(jar.lastIndexOf(flag) + flag.length(), jar.indexOf(jarExt)); - return tfVersion; - } - - /* - * Get the version number from the jar file - */ - public static String getTFVersion(boolean fiji) { - if (fiji) { - TensorFlowVersion tfVersion = StartTensorflowService.getTfService().getTensorFlowVersion(); - return tfVersion.getVersionNumber(); - } else { - return getTFVersionIJ(); - } - } - - /* - * Retrieves the TF version that is going to be used for the plugin. - * In order to do that, the method searches in two locations where the - *.jars might be: in the plugins folder or in the jars folder - */ - public static String getTFVersionIJ() { - String tfJni = ""; - try { - URL resource = ClassLoader.getSystemClassLoader().getResource("org/tensorflow/native"); - if (resource == null) - resource = IJ.getClassLoader().getResource("org/tensorflow/native"); - JarURLConnection connection = null; - connection = (JarURLConnection) resource.openConnection(); - tfJni = connection.getJarFileURL().getFile(); - } catch (Exception e) { - tfJni = getLibTfJar(); - if (!tfJni.contains("jar")) - return tfJni; - } - String tfVersion = getTfVersionFromJar(tfJni); - - if (tfVersion.contains("gpu")) { - tfVersion = tfVersion.substring(tfVersion.toLowerCase().indexOf("gpu_") + 5) + " GPU"; - } - return tfVersion; - } - - /* - * Finds the directory where the tf jar is - */ - public static String getLibTfJar() { - - // Search in the plugins folder - String ijDirectory = IJ.getDirectory("imagej") + File.separator; - // TODO remove - //ijDirectory = "C:\\Users\\Carlos(tfg)\\Videos\\Fiji.app"; - String pluginsDirectory = ijDirectory + File.separator + "plugins" + File.separator; - String pluginsJar = findTFJar(pluginsDirectory); - - // Search in the jars folder - String jarDirectory = ijDirectory + File.separator + "jars" + File.separator; - String jarsJar = findTFJar(jarDirectory); - - // Check that there is only one jar file present in both folders - if (jarsJar.equals(pluginsJar) && jarsJar.equals("")) { - return "-No Tensorflow version found-"; - } else if (jarsJar.toLowerCase().contains("more than 1 version") || pluginsJar.toLowerCase().contains("more than 1 version")) { - return "-More than one tensorflow version present-"; - } else if (jarsJar.toLowerCase().contains("tensorflow") && pluginsJar.toLowerCase().contains("tensorflow") && !jarsJar.equals(pluginsJar)) { - return "-The plugins and jars directories contains a different version of TF each-"; - } - - // Find which of them is actually the TF jni jar - String tfJni = pluginsJar; - if (tfJni.equals("") == true) { - tfJni = jarsJar; - } - return tfJni; - } - - /* - * Finds the file corresponding to the tf jar - */ - public static String findTFJar(String folderDir) { - // Find the file libtensorflow_jni.jar - - // Name of the TF jni without the version - String jarName = "libtensorflow_jni"; - // Auxiliary variable to make sure we only have one TF jni - int nJars = 0; - String tfJar = ""; - - File folder = new File(folderDir); - File[] listOfFiles = folder.listFiles(); - if (listOfFiles == null) - return ""; - - for (File file : listOfFiles) { - if (file.isFile() == true) { - String fileName = file.getAbsolutePath(); - if (fileName.indexOf(jarName) != -1) { - nJars ++; - tfJar = fileName; - } - } - } - - if (nJars == 0) { - tfJar = ""; - } else if (nJars >1) { - tfJar = "more than 1 version"; - } - - return tfJar; - } - - /* - * Get the version number from the jar file - */ - public static String getTfVersionFromJar(String jar) { - // Name of the TF jni without the version - jar = jar.toLowerCase(); - String flag = "libtensorflow_jni"; - String jarExt = ".jar"; - String tfVersion = jar.substring(jar.lastIndexOf(flag) + flag.length() + 1, jar.indexOf(jarExt)); - return tfVersion; - } } diff --git a/src/main/java/deepimagej/ImagePlus2Tensor.java b/src/main/java/deepimagej/ImagePlus2Tensor.java index 1870677b..9cfeeb07 100755 --- a/src/main/java/deepimagej/ImagePlus2Tensor.java +++ b/src/main/java/deepimagej/ImagePlus2Tensor.java @@ -44,19 +44,28 @@ package deepimagej; -import java.nio.FloatBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.stream.LongStream; -import org.tensorflow.Tensor; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.Shape; import deepimagej.exceptions.BatchSizeBiggerThanOne; import deepimagej.exceptions.IncorrectNumberOfDimensions; -import deepimagej.tools.ArrayOperations; import ij.IJ; import ij.ImagePlus; import ij.process.ImageProcessor; +import io.bioimage.modelrunner.tensor.Tensor; +import net.imglib2.Cursor; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.img.array.ArrayImg; +import net.imglib2.img.array.ArrayImgFactory; +import net.imglib2.type.Type; +import net.imglib2.type.numeric.NumericType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Util; +import net.imglib2.view.IntervalView; public class ImagePlus2Tensor { @@ -64,591 +73,237 @@ public class ImagePlus2Tensor { // TODO allow batch size != 1 // Methods to transform a DJL Pytorch and TF tensors into ImageJ ImagePlus - public static NDArray imPlus2tensor(NDManager manager, ImagePlus img, String form, String ptVersion){ + public static < T extends NumericType< T > & RealType< T > > RandomAccessibleInterval< T > imPlus2tensor(ImagePlus img, String form){ // Convert ImagePlus into tensor calling the corresponding // method depending on the dimensions of the required tensor // Find the number of dimensions of the tensor - int nDim = form.length(); - NDArray tensor = null; - if (nDim >= 2 && nDim <= 5) { - tensor = implus2NDArray(img, form, manager, ptVersion); - } - return tensor; - } - - /* - * Check that if the DJL Pytorch version is older than - * version 1.7.0 - */ - public static boolean olderThanPytorch170(String ptVersion) { - boolean older = true; - try { - int firstDot = ptVersion.indexOf("."); - int secondDot = ptVersion.substring(firstDot + 1).indexOf(".") + firstDot + 1; - int majorVersion = Integer.parseInt(ptVersion.substring(0, firstDot)); - int minorVersion = Integer.parseInt(ptVersion.substring(firstDot + 1, secondDot)); - if ((majorVersion >= 1 && minorVersion >= 7) || majorVersion > 1) { - older = false; - } - - } catch(Exception ex) { - if (!ptVersion.contains("1.4.") && !ptVersion.contains("1.5.") && !ptVersion.contains("1.6.")) { - older = false; - } - } - return older; - } - - public static NDArray implus2NDArray(ImagePlus img, String form, NDManager manager, String ptVersion){ - // Create a float array of four dimensions out of an - // ImagePlus object - float[] matImage; - // Initialise ImageProcessor variable used later - ImageProcessor ip; - int[] dims = img.getDimensions(); - int xSize = dims[0]; - int ySize = dims[1]; - int cSize = dims[2]; - int zSize = dims[3]; - // TODO allow different batch sizes - int batch = 1; - int[] tensorDims = new int[] {1, 1, 1, 1, 1}; - // Create aux variable to indicate - // if it is channels one of the dimensions of - // the tensor or it is the batch size - int fBatch = -1; - int fChannel = -1; - int fDepth = -1; - int fWidth = -1; - int fHeight = -1; - - // For DJL Pytorch versions <1.7.0, the batch size is not included in the tensor - long[] arrayShape = new long[form.length() - 1]; - boolean old = olderThanPytorch170(ptVersion); - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - tensorDims[fBatch] = batch; - // For DJL Pytorch versions >=1.7.0, the batch size is included in the tensor - if (!old) { - arrayShape = new long[form.length()]; - arrayShape[fBatch] = (long) batch; - } else { - String auxForm = form.substring(0, fBatch) + form.substring(fBatch + 1); - IJ.log("WARNING: DJL Pytorch versions <=1.6.0 do not allow definition of the batch size."); - IJ.log("WARNING: Image input tensor dimension organization has changed: " + form + " --> " + auxForm); - } - } else { - arrayShape = new long[form.length()]; - fBatch = form.length(); - form += "B"; - } - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - tensorDims[fHeight] = ySize; - if (fBatch != -1 && fHeight > fBatch && old) - arrayShape[fHeight - 1] = (long) ySize; - else - arrayShape[fHeight] = (long) ySize; - } else { - fHeight = form.length(); - form += "Y"; - } - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - tensorDims[fWidth] = xSize; - if (fBatch != -1 && fWidth > fBatch && old) - arrayShape[fWidth - 1] = (long) xSize; - else - arrayShape[fWidth] = (long) xSize; - } else { - fWidth = form.length(); - form += "X"; - } - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - tensorDims[fChannel] = cSize; - if (fBatch != -1 && fChannel > fBatch && old) - arrayShape[fChannel - 1] = (long) cSize; - else - arrayShape[fChannel] = (long) cSize; - } else { - fChannel = form.length(); - form += "C"; - } - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - tensorDims[fDepth] = zSize; - if (fBatch != -1 && fDepth > fBatch && old) - arrayShape[fDepth - 1] = (long) zSize; - else - arrayShape[fDepth] = (long) zSize; - } else { - fDepth = form.length(); - form += "Z"; - } - matImage = new float[tensorDims[0] * tensorDims[1] * tensorDims[2] * tensorDims[3] * tensorDims[4]]; + int[] tensorDimOrder = Tensor.convertToTensorDimOrder(form); - // Make sure the array is written from last dimension to first dimension. - // For example, for CYX we first iterate over all the X, then over the Y and then - // over the C - int[] auxCounter = new int[5]; - int pos = 0; - for (int t0 = 0; t0 < tensorDims[0]; t0 ++) { - auxCounter[0] = t0; - for (int t1 = 0; t1 < tensorDims[1]; t1 ++) { - auxCounter[1] = t1; - for (int t2 = 0; t2 < tensorDims[2]; t2 ++) { - auxCounter[2] = t2; - for (int t3 = 0; t3 < tensorDims[3]; t3 ++) { - auxCounter[3] = t3; - for (int t4 = 0; t4 < tensorDims[4]; t4 ++) { - auxCounter[4] = t4; - - img.setPositionWithoutUpdate(auxCounter[fChannel] + 1, auxCounter[fDepth] + 1, 1); - ip = img.getProcessor(); - matImage[pos ++] = ip.getPixelValue(auxCounter[fWidth], auxCounter[fHeight]); - } - } - } - } - } - FloatBuffer outBuff = FloatBuffer.wrap(matImage); - NDArray tensor = manager.create(matImage, new Shape(arrayShape)); - return tensor; - } - - public static Tensor implus2TensorFloat(ImagePlus img, String form){ - // Create a float array of four dimensions out of an - // ImagePlus object - float[] matImage; - // Initialise ImageProcessor variable used later - ImageProcessor ip; - int[] dims = img.getDimensions(); - int xSize = dims[0]; - int ySize = dims[1]; - int cSize = dims[2]; - int zSize = dims[3]; // TODO allow different batch sizes - int batch = 1; - int[] tensorDims = new int[] {1, 1, 1, 1, 1}; - // Create aux variable to indicate - // if it is channels one of the dimensions of - // the tensor or it is the batch size - int fBatch = -1; - int fChannel = -1; - int fDepth = -1; - int fWidth = -1; - int fHeight = -1; - - long[] arrayShape = new long[form.length()]; - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - tensorDims[fBatch] = batch; - arrayShape[fBatch] = (long) batch; - } else { - fBatch = form.length(); - form += "B"; - } - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - tensorDims[fHeight] = ySize; - arrayShape[fHeight] = (long) ySize; - } else { - fHeight = form.length(); - form += "Y"; - } - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - tensorDims[fWidth] = xSize; - arrayShape[fWidth] = (long) xSize; - } else { - fWidth = form.length(); - form += "X"; - } - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - tensorDims[fChannel] = cSize; - arrayShape[fChannel] = (long) cSize; - } else { - fChannel = form.length(); - form += "C"; - } - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - tensorDims[fDepth] = zSize; - arrayShape[fDepth] = (long) zSize; - } else { - fDepth = form.length(); - form += "Z"; - } - matImage = new float[tensorDims[0] * tensorDims[1] * tensorDims[2] * tensorDims[3] * tensorDims[4]]; - + // Create a cursor + int[] tensorDims = getTensorCompleteTensorDimensions(img.getDimensions(), tensorDimOrder); + // Find the correspondence between the sequence axes order and + // the tensor axes order + int[] orderCorrespondence = getSequenceDimOrder(tensorDimOrder); // Make sure the array is written from last dimension to first dimension. // For example, for CYX we first iterate over all the X, then over the Y and then // over the C int[] auxCounter = new int[5]; - int pos = 0; - for (int t0 = 0; t0 < tensorDims[0]; t0 ++) { - auxCounter[0] = t0; - for (int t1 = 0; t1 < tensorDims[1]; t1 ++) { - auxCounter[1] = t1; - for (int t2 = 0; t2 < tensorDims[2]; t2 ++) { - auxCounter[2] = t2; - for (int t3 = 0; t3 < tensorDims[3]; t3 ++) { - auxCounter[3] = t3; - for (int t4 = 0; t4 < tensorDims[4]; t4 ++) { - auxCounter[4] = t4; - - img.setPositionWithoutUpdate(auxCounter[fChannel] + 1, auxCounter[fDepth] + 1, 1); - ip = img.getProcessor(); - matImage[pos ++] = ip.getPixelValue(auxCounter[fWidth], auxCounter[fHeight]); - } - } - } - } - } - FloatBuffer outBuff = FloatBuffer.wrap(matImage); - - Tensor tensor = Tensor.create(arrayShape, outBuff); - return tensor; - } - - - /////////// Methods to transform an NDArray tensor into an ImageJ ImagePlus - - - public static ImagePlus NDArray2ImagePlus(NDArray tensor, String form, String name, String ptVersion) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne{ - // This method copies the information from the tensor to a matrix. At first only works - // if the batch size is 1 - - // ImagePlus dimensions in the TensorFlow style. In this case we consider B as T, - // as for the moment both are going to be 1 - - ImagePlus imPlus = null; - long[] tensorShape = tensor.getShape().getShape(); - boolean old = olderThanPytorch170(ptVersion); - int batchIndex = form.indexOf("B"); - - // TODO should batch be eliminated always or only when the dimensions are incorrect - if (old && batchIndex != -1) { - String oldForm = "" + form; - form = oldForm.substring(0, batchIndex) + oldForm.substring(batchIndex + 1); - IJ.log("WARNING: DJL Pytorch versions <=1.6.0 do not allow definition of the batch size."); - IJ.log("WARNING: Output tensor '" + name + "' dimension organization has changed: " + oldForm + " --> " + form); - } - - if (tensorShape.length != form.length()) - throw new IncorrectNumberOfDimensions(tensorShape, form, name); - int[] completeTensorShape = longShape6(tensorShape); - int[] imageDims = {1, 1, 1, 1, 1}; - - // TODO add possibility of batch>1 - if (batchIndex != -1 && tensorShape[batchIndex] > 1) - throw new BatchSizeBiggerThanOne(tensorShape, form, name); - - int fBatch; - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - imageDims[4] = (int) tensorShape[fBatch]; - } else { - fBatch = form.length(); - form += "B"; - } - int fHeight; - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - imageDims[1] = (int) tensorShape[fHeight]; - } else { - fHeight = form.length(); - form += "Y"; - } - int fWidth; - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - imageDims[0] = (int) tensorShape[fWidth]; - } else { - fWidth = form.length(); - form += "X"; - } - int fChannel; - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - imageDims[2] = (int) tensorShape[fChannel]; - } else { - fChannel = form.length(); - form += "C"; - } - int fDepth; - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - imageDims[3] = (int) tensorShape[fDepth]; - } else { - fDepth = form.length(); - form += "Z"; - } - - float[] flatImageArray = tensor.toFloatArray(); - double[][][][][] matImage = new double[imageDims[0]][imageDims[1]][imageDims[2]][imageDims[3]][imageDims[4]]; - - int pos = 0; - int[] auxInd = {0, 0, 0, 0, 0}; - for (int i0 = 0; i0 < completeTensorShape[0]; i0 ++) { - auxInd[0] = i0; - for (int i1 = 0; i1 < completeTensorShape[1]; i1 ++) { - auxInd[1] = i1; - for (int i2 = 0; i2 < completeTensorShape[2]; i2 ++) { - auxInd[2] = i2; - for (int i3 = 0; i3 < completeTensorShape[3]; i3 ++) { - auxInd[3] = i3; - for (int i4 = 0; i4 < completeTensorShape[4]; i4 ++) { - auxInd[4] = i4; - matImage[auxInd[fWidth]][auxInd[fHeight]][auxInd[fChannel]][auxInd[fDepth]][auxInd[fBatch]] = (double) flatImageArray[pos ++]; - } - } - } - } - } - imPlus = ArrayOperations.convertArrayToImagePlus(matImage, imageDims); - - return imPlus; - } + final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>(new FloatType()); + long[] tensorSize = LongStream.range(0, tensorDimOrder.length).map(i -> tensorDims[(int) i]).toArray(); + final Img< FloatType > tensor = factory.create( tensorSize ); + Cursor tensorCursor = tensor.cursor(); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] position = tensorCursor.positionAsLongArray(); + for (int i = 0; i < position.length; i ++) { + auxCounter[i] = (int) position[i]; + } + // TODO remove + int[] icyInd = {auxCounter[orderCorrespondence[0]], auxCounter[orderCorrespondence[1]], auxCounter[orderCorrespondence[2]], auxCounter[orderCorrespondence[3]], auxCounter[orderCorrespondence[4]]}; + + img.setPositionWithoutUpdate(icyInd[2] + 1, icyInd[3] + 1, icyInd[4] + 1); + ImageProcessor ip = img.getProcessor(); + float val = ip.getPixelValue(icyInd[0], icyInd[1]); + tensorCursor.get().set(val); + } + return (RandomAccessibleInterval) tensor; + } // TODO make specific for different types - public static ImagePlus tensor2ImagePlus(Tensor tensor, String form, String name) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne{ + public static < T extends NumericType< T > & RealType< T > > ImagePlus tensor2ImagePlus(RandomAccessibleInterval data, String form) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne{ // This method copies the information from the tensor to a matrix. At first only works // if the batch size is 1 // ImagePlus dimensions in the TensorFlow style. In this case we consider B as T, // as for the moment both are going to be 1 - ImagePlus imPlus = null; - long[] tensorShape = tensor.shape(); - if (tensorShape.length != form.length()) - throw new IncorrectNumberOfDimensions(tensorShape, form, name); - int[] completeTensorShape = longShape6(tensorShape); - int[] imageDims = {1, 1, 1, 1, 1}; - - int batchIndex = form.indexOf("B"); - - // TODO add possibility of batch>1 - if (batchIndex != -1 && tensorShape[batchIndex] > 1) - throw new BatchSizeBiggerThanOne(tensorShape, form, name); - - int fBatch; - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - imageDims[4] = (int) tensorShape[fBatch]; - } else { - fBatch = form.length(); - form += "B"; - } - int fHeight; - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - imageDims[1] = (int) tensorShape[fHeight]; - } else { - fHeight = form.length(); - form += "Y"; - } - int fWidth; - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - imageDims[0] = (int) tensorShape[fWidth]; - } else { - fWidth = form.length(); - form += "X"; - } - int fChannel; - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - imageDims[2] = (int) tensorShape[fChannel]; - } else { - fChannel = form.length(); - form += "C"; - } - int fDepth; - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - imageDims[3] = (int) tensorShape[fDepth]; - } else { - fDepth = form.length(); - form += "Z"; - } - - float[] flatImageArray = new float[imageDims[0] * imageDims[1] * imageDims[2] * imageDims[3] * imageDims[4]]; - - FloatBuffer outBuff = FloatBuffer.wrap(flatImageArray); - tensor.writeTo(outBuff); - double[][][][][] matImage = new double[imageDims[0]][imageDims[1]][imageDims[2]][imageDims[3]][imageDims[4]]; - - int pos = 0; + // TODO adapt to several batch sizes + long[] dataShape = data.dimensionsAsLongArray(); + + if (dataShape.length != form.length()) + throw new IllegalArgumentException("Tensor has " + dataShape.length + " dimensions " + + "whereas the specified axes have " + form.length() + " (" + form + ")."); + int[] axesOrder = Tensor.convertToTensorDimOrder(form); + Type dtype = Util.getTypeFromInterval(data); + // Check if the axes order is valid + checkTensorDimOrder(dataShape, axesOrder); + // Add missing dimensions to the tensor axes order. The missing dimensions + // are added at the end + int[] completeDimOrder = completeImageDimensions(axesOrder); + // Get the order of the tensor with respect to the axes of an ImageJ sequence + int[] seqDimOrder = getSequenceDimOrder(completeDimOrder); + // GEt the size of the tensor for every dimension existing in an Icy sequence + int[] seqSize = getSequenceSize(axesOrder, dataShape); + // Create result sequence + ImagePlus sequence = IJ.createHyperStack("output", seqSize[0], seqSize[1], seqSize[2], seqSize[3], + seqSize[4], 32); + // Create an array with the shape of the tensor for every dimension in Icy + // REcall that Icy axes are organized as [xyzbc] but in this plugin + // to keep the convention with ImageJ and Fiji, we will always act as + // they were [xyczb]. That is why in the following command, after + // tensorSize[seqDimOrder[1]], it goes tensorSize[seqDimOrder[4]], + // instead of tensorSize[seqDimOrder[2]], because seqSize uses + // Icy axes, but seqDimOrder refers to the tensor from ImageJ axes + int[] tensorShape = new int[5]; + tensorShape[seqDimOrder[0]] = seqSize[0]; tensorShape[seqDimOrder[1]] = seqSize[1]; + tensorShape[seqDimOrder[2]] = seqSize[2]; tensorShape[seqDimOrder[3]] = seqSize[3]; + tensorShape[seqDimOrder[4]] = seqSize[4]; int[] auxInd = {0, 0, 0, 0, 0}; - for (int i0 = 0; i0 < completeTensorShape[0]; i0 ++) { - auxInd[0] = i0; - for (int i1 = 0; i1 < completeTensorShape[1]; i1 ++) { - auxInd[1] = i1; - for (int i2 = 0; i2 < completeTensorShape[2]; i2 ++) { - auxInd[2] = i2; - for (int i3 = 0; i3 < completeTensorShape[3]; i3 ++) { - auxInd[3] = i3; - for (int i4 = 0; i4 < completeTensorShape[4]; i4 ++) { - auxInd[4] = i4; - matImage[auxInd[fWidth]][auxInd[fHeight]][auxInd[fChannel]][auxInd[fDepth]][auxInd[fBatch]] = (double) flatImageArray[pos ++]; - } - } - } - } - } - imPlus = ArrayOperations.convertArrayToImagePlus(matImage, imageDims); - return imPlus; - } - - private static int[] longShape6(long[] shape) { - // First convert add the needed entries with value 1 to the array - // until its length is 5 - int[] f_shape = { 1, 1, 1, 1, 1, 1 }; - for (int i = 0; i < shape.length; i++) { - f_shape[i] = (int) shape[i]; - } - return f_shape; + Cursor tensorCursor; + if (data instanceof IntervalView) + tensorCursor = ((IntervalView) data).cursor(); + else if (data instanceof Img) + tensorCursor = ((Img) data).cursor(); + else if (data instanceof ArrayImg) + tensorCursor = ((ArrayImg) data).cursor(); + else + throw new IllegalArgumentException("First parameter has to be an instance of " + Img.class + + " or " + IntervalView.class + " or " + ArrayImg.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + for (int i = 0; i < cursorPos.length; i ++) { + auxInd[i] = (int) cursorPos[i]; + } + float val = tensorCursor.get().getRealFloat(); + int[] icyInd = {auxInd[seqDimOrder[0]], auxInd[seqDimOrder[1]], auxInd[seqDimOrder[2]], auxInd[seqDimOrder[3]], auxInd[seqDimOrder[4]]}; + sequence.setPositionWithoutUpdate(icyInd[2] + 1, icyInd[3] + 1, icyInd[4] + 1); + ImageProcessor ip = sequence.getProcessor(); + ip.putPixelValue(icyInd[0], icyInd[1], (float) val); + } + return sequence; } - - // Convert image plus into int array - - /* - * Method that gets an long[] array with the shape of the tensor/image + + /** + * Create an array where each position corresponds to the size + * of the tensor that will be created. The array has all the possible + * dimensions of a sequence. For example for a sequence with X->255, + * Y->256, C->3, Z->1, B->1, and axes order "BYXC" (which will be represented + * by the variable 'arrayDimOrder' as [4, 1, 0, 2]), the resulting array + * would be [1, 256, 256, 3, 1] + * @param sequence: image from which the tensor will be created + * @param arrayDimOrder: axes order of the tensor + * @return array with the size of each tensor in the corresponding dimension */ - public static long[] getTensorShape(ImagePlus img, String form) { - int[] dims = img.getDimensions(); - int xSize = dims[0]; - int ySize = dims[1]; - int cSize = dims[2]; - int zSize = dims[3]; - // TODO allow different batch sizes - int batch = 1; - // Create aux variable to indicate - // if it is channels one of the dimensions of - // the tensor or it is the batch size - int fBatch = -1; - int fChannel = -1; - int fDepth = -1; - int fWidth = -1; - int fHeight = -1; + private static int[] getTensorCompleteTensorDimensions(int[] dims, int[] arrayDimOrder) + { + // Map the dimensions integer (ie x->0, y->1, c->2, z->3, t->4) + HashMap dimsMap = new HashMap(); + dimsMap.put(0, dims[0]); + dimsMap.put(1, dims[1]); + dimsMap.put(2, dims[2]); + dimsMap.put(3, dims[3]); + dimsMap.put(4, dims[4]); + int[] tensorDims = new int[] {1, 1, 1, 1, 1}; + for (int i = 0; i < arrayDimOrder.length; i ++) + tensorDims[i] = dimsMap.get(arrayDimOrder[i]); + return tensorDims; + } - long[] arrayShape = new long[form.length()];; - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - arrayShape[fBatch] = (long) batch; - } - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - arrayShape[fHeight] = (long) ySize; - } - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - arrayShape[fWidth] = (long) xSize; - } - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - arrayShape[fChannel] = (long) cSize; - } - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - arrayShape[fDepth] = (long) zSize; - } - return arrayShape; - } + /** + * Computes the sequence dimension order with respect to the tensor dimensions. + * + * @param tensorDimOrder + * The Tensor dimension order. + * @return The sequence dimension order. + */ + private static int[] getSequenceDimOrder(int[] tensorDimOrder) + { + tensorDimOrder = tensorDimOrderAllDims(tensorDimOrder); + int[] imgDimOrder = new int[] {-1, -1, -1, -1, -1}; + for (int i = 0; i < tensorDimOrder.length; i++) + { + imgDimOrder[tensorDimOrder[i]] = i; + } + return imgDimOrder; + } + + /** + * Create a dimensions (axes) order array that contains all the possible dimensions, + * adding the ones missing from the tensor at the end of the array + * @param tensorDimOrder + * the tensor axes order in array form + * @return the tensor axes order but with all the possible dims + */ + private static int[] tensorDimOrderAllDims(int[] tensorDimOrder) { + int[] longDimOrder = new int[5]; + // Auxiliary array with dimensions ordered + int[] auxArr = new int[] {0,1,2,3,4}; + int i; + for (i = 0; i < tensorDimOrder.length; i ++) { + longDimOrder[i] = tensorDimOrder[i]; + auxArr[tensorDimOrder[i]] = -1; + } + + for (int aa : auxArr) { + if (aa != -1) + longDimOrder[i ++] = aa; + } + return longDimOrder; + } - /* - * Method that converts ImagePLus into float[] array. + /** + * Check that the dimensions order provided is compatible with + * the output array given. If it is not, the method throws an exception, + * if it is, nothing happens + * @param dataShape + * shape of the data array + * @param tensorDimOrder + * dimensions (axes) order given + * @throws IllegalArgumentException if the dimensions do not have the same length */ + private static void checkTensorDimOrder(long[] dataShape, int[] tensorDimOrder) + throws IllegalArgumentException + { + if (tensorDimOrder.length != dataShape.length) + { + throw new IllegalArgumentException( + "Tensor dim order array length is different than number of dimensions in tensor (" + + tensorDimOrder.length + " != " + dataShape.length + ")"); + } + } + + // TODO improve efficiency + /** + * Add to the tensor axes order array the dimensions missing, + * the dimensions are always added at the end. + * For example, for a tensor with axes [byxc], its tensorDimOrder + * would be transformed from [4,1,0,2] to [4,1,0,2,3] ([byxcz]) + * @param tensorDimOrder; axes order of the tensor + * @return new axes order with dimensions at the end + */ + private static int[] completeImageDimensions(int[] tensorDimOrder) { + int nTotalImageDims = 5; + int nTensorDims = tensorDimOrder.length; + int missingDims = nTotalImageDims - nTensorDims; + int[] missingDimsArr = new int[missingDims]; + int c = 0; + for (int ii : new int[] {0, 1, 2, 3, 4}) { + if (Arrays.stream(tensorDimOrder).noneMatch(i -> i == ii)) + missingDimsArr[c ++] = ii; + } + int[] completeDims = new int[nTotalImageDims]; + System.arraycopy(tensorDimOrder, 0, completeDims, 0, tensorDimOrder.length); + System.arraycopy(missingDimsArr, 0, completeDims, tensorDimOrder.length, missingDimsArr.length); + return completeDims; + } - // TODO use this as basis for the implus2tensor and implus2ndarray or remove - public static float[] implus2IntArray(ImagePlus img, String form){ - // Create a float array of four dimensions out of an - // ImagePlus object - float[] matImage; - // Initialise ImageProcessor variable used later - ImageProcessor ip; - int[] dims = img.getDimensions(); - int xSize = dims[0]; - int ySize = dims[1]; - int cSize = dims[2]; - int zSize = dims[3]; - // TODO allow different batch sizes - int batch = 1; - int[] tensorDims = new int[] {1, 1, 1, 1, 1}; - // Create aux variable to indicate - // if it is channels one of the dimensions of - // the tensor or it is the batch size - int fBatch = -1; - int fChannel = -1; - int fDepth = -1; - int fWidth = -1; - int fHeight = -1; - - if (form.indexOf("B") != -1) { - fBatch = form.indexOf("B"); - tensorDims[fBatch] = batch; - } else { - fBatch = form.length(); - form += "B"; - } - if (form.indexOf("Y") != -1) { - fHeight = form.indexOf("Y"); - tensorDims[fHeight] = ySize; - } else { - fHeight = form.length(); - form += "Y"; - } - if (form.indexOf("X") != -1) { - fWidth = form.indexOf("X"); - tensorDims[fWidth] = xSize; - } else { - fWidth = form.length(); - form += "X"; - } - if (form.indexOf("C") != -1) { - fChannel = form.indexOf("C"); - tensorDims[fChannel] = cSize; - } else { - fChannel = form.length(); - form += "C"; - } - if (form.indexOf("Z") != -1) { - fDepth = form.indexOf("Z"); - tensorDims[fDepth] = zSize; - } else { - fDepth = form.length(); - form += "Z"; - } - matImage = new float[tensorDims[0] * tensorDims[1] * tensorDims[2] * tensorDims[3] * tensorDims[4]]; - - // Make sure the array is written from last dimension to first dimension. - // For example, for CYX we first iterate over all the X, then over the Y and then - // over the C - int[] auxCounter = new int[5]; - int pos = 0; - for (int t0 = 0; t0 < tensorDims[0]; t0 ++) { - auxCounter[0] = t0; - for (int t1 = 0; t1 < tensorDims[1]; t1 ++) { - auxCounter[1] = t1; - for (int t2 = 0; t2 < tensorDims[2]; t2 ++) { - auxCounter[2] = t2; - for (int t3 = 0; t3 < tensorDims[3]; t3 ++) { - auxCounter[3] = t3; - for (int t4 = 0; t4 < tensorDims[4]; t4 ++) { - auxCounter[4] = t4; - - img.setPositionWithoutUpdate(auxCounter[fChannel] + 1, auxCounter[fDepth] + 1, 1); - ip = img.getProcessor(); - matImage[pos ++] = ip.getPixelValue(auxCounter[fWidth], auxCounter[fHeight]); - } - } - } - } - } - - return matImage; - } + /** + * Get the size of each of the dimensions expressed in an array that + * follows the ImageJ axes order -> xyczt + * @param seqDimOrder + * order of the dimensions of the Icy sequence with respect to the tensor + * @param shape + * shape of the dimensions of the data + * @return array containing the size for each dimension + */ + private static int[] getSequenceSize(int[] seqDimOrder, long[] shape) + { + int[] dims = new int[] {1, 1, 1, 1, 1}; + for (int i = 0; i < seqDimOrder.length; i ++) { + dims[seqDimOrder[i]] = (int) shape[i]; + } + return dims; + } } diff --git a/src/main/java/deepimagej/InstallerDialog.java b/src/main/java/deepimagej/InstallerDialog.java index 0e5e82e6..dc9d92b4 100755 --- a/src/main/java/deepimagej/InstallerDialog.java +++ b/src/main/java/deepimagej/InstallerDialog.java @@ -295,7 +295,7 @@ public long getFileSize() { /* * Get file size of the online model */ - public long getFileSize(URL url) { + public static long getFileSize(URL url) { HttpURLConnection conn = null; try { conn = (HttpURLConnection) url.openConnection(); diff --git a/src/main/java/deepimagej/Parameters.java b/src/main/java/deepimagej/Parameters.java index 2629525a..60446624 100755 --- a/src/main/java/deepimagej/Parameters.java +++ b/src/main/java/deepimagej/Parameters.java @@ -59,6 +59,7 @@ import deepimagej.tools.DijTensor; import deepimagej.tools.YAMLUtils; +import deepimagej.tools.weights.ModelWeight; import ij.ImagePlus; public class Parameters { @@ -230,6 +231,7 @@ public class Parameters { public String framework = ""; public String tfSource = null; public String ptSource = null; + public String onnxSource = null; public String description = null; public String git_repo = null; @@ -277,6 +279,7 @@ public class Parameters { * there is a Pytorch model */ public String ptSha256 = ""; + public String onnxSha256 = ""; /* * Specifies if the folder contains a Bioimage Zoo model */ @@ -297,6 +300,10 @@ public class Parameters { * weights folder */ public String selectedModelPath = ""; + /** + * Weights object specified in the YAml file + */ + public ModelWeight weights; public Parameters(boolean valid, String path, boolean isDeveloper) { // If the model is not valid or we are in the developer plugin, @@ -397,10 +404,12 @@ public Parameters(boolean valid, String path, boolean isDeveloper) { git_repo = (String) "" + obj.get("git_repo"); LinkedHashMap weights = (LinkedHashMap) obj.get("weights"); + this.weights = ModelWeight.build(weights); // Look for the valid weights tags Set weightFormats = weights.keySet(); boolean tf = false; boolean pt = false; + boolean onnx = false; for (String format : weightFormats) { if (format.equals("tensorflow_saved_model_bundle")) { tf = true; @@ -453,6 +462,11 @@ else if (!str.contentEquals(defaultFlag)) ptAttachmentsNotIncluded.add(str); } } + } else if (format.equals("onnx")) { + onnx = true; + HashMap onnxMap = ((HashMap) weights.get("onnx")); + onnxSource = (String) onnxMap.get("source"); + } } @@ -478,7 +492,11 @@ else if (!str.contentEquals(defaultFlag)) else aux = ((LinkedHashMap) weights.get("torchscript")).get("sha256"); ptSha256 = "" + (String) aux; - } else if (!tf && !pt) { + } else if (onnx){ + framework = "onnx"; + onnxSha256 = (String) "" + ((LinkedHashMap) weights.get("onnx")).get("sha256"); + + } else if (!tf && !pt && !onnx) { completeConfig = false; return; } diff --git a/src/main/java/deepimagej/RunnerPt.java b/src/main/java/deepimagej/RunnerDL.java similarity index 87% rename from src/main/java/deepimagej/RunnerPt.java rename to src/main/java/deepimagej/RunnerDL.java index 15452246..b7fae769 100755 --- a/src/main/java/deepimagej/RunnerPt.java +++ b/src/main/java/deepimagej/RunnerDL.java @@ -44,22 +44,12 @@ package deepimagej; -import java.nio.FloatBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.concurrent.Callable; -import org.tensorflow.Tensor; - -import ai.djl.engine.EngineException; -import ai.djl.inference.Predictor; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.Shape; -import ai.djl.repository.zoo.ZooModel; import deepimagej.exceptions.BatchSizeBiggerThanOne; import deepimagej.exceptions.IncorrectNumberOfDimensions; import deepimagej.tools.ArrayOperations; @@ -71,8 +61,14 @@ import ij.IJ; import ij.ImagePlus; import ij.measure.ResultsTable; +import io.bioimage.modelrunner.model.Model; +import io.bioimage.modelrunner.tensor.Tensor; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.util.Util; -public class RunnerPt implements Callable> { +public class RunnerDL < T extends RealType< T > & NativeType< T > > implements Callable> { private HashMap inputMap; private DeepImageJ dp; @@ -82,7 +78,7 @@ public class RunnerPt implements Callable> { private int totalPatch = 0; public String error = ""; - public RunnerPt(DeepImageJ dp, RunnerProgress rp,HashMap inputMap, Log log) { + public RunnerDL(DeepImageJ dp, RunnerProgress rp,HashMap inputMap, Log log) { this.dp = dp; this.rp = rp; this.log = log; @@ -91,7 +87,7 @@ public RunnerPt(DeepImageJ dp, RunnerProgress rp,HashMap inputMap } @Override - public HashMap call() { + public HashMap call() { if (rp != null ) { rp.setInfoTag("applyModel"); @@ -104,7 +100,7 @@ public HashMap call() { Parameters params = dp.params; // Load the model first - ZooModel model = dp.getTorchModel(); + Model model = dp.getModel(); if (log.getLevel() >= 1) log.print("model " + (model == null)); @@ -141,14 +137,11 @@ public HashMap call() { + "the preprocessing but it is not."; IJ.error(error); return null; - } else if (tensorVal instanceof Tensor) { - parameterMap.put(tensor.name, (Tensor) tensorVal); - } else if (tensorVal instanceof NDArray) { - parameterMap.put(tensor.name, (NDArray) tensorVal); + } else if (tensorVal instanceof Tensor) { + parameterMap.put(tensor.name, (Tensor) tensorVal); } else { // TODO improve error message and review what can be a preprocessing output - error = "Output of the preprocessing should be either a Tensor object" - + " or a NDArray object"; + error = "Output of the preprocessing should be a Biimage.io Tensor"; IJ.error(error); return null; } @@ -305,14 +298,12 @@ public HashMap call() { if (log.getLevel() >= 1) log.print("start " + npx + "x" + npy); - - NDList inputTensors = new NDList(); for (int i = 0; i < npx; i++) { for (int j = 0; j < npy; j++) { for (int z = 0; z < npz; z++) { // TODO reduce this mega big loop to something more modular currentPatch++; - System.out.println("[DEBUG] (Pytorch) Patch " + currentPatch + "/" + totalPatch); + System.out.println("[DEBUG] (Inference) Patch " + currentPatch + "/" + totalPatch); if (log.getLevel() >= 1) log.print("currentPatch " + currentPatch); if (rp != null && rp.isStopped()) { @@ -386,8 +377,8 @@ public HashMap call() { } // TODO optimise (take the try out of the loop) - try (NDManager manager = NDManager.newBaseManager()) { - inputTensors = getInputTensors(manager, inputTensors, params.inputList, parameterMap, + try { + List> inputTensors = getInputTensors(params.inputList, parameterMap, patch, params.pytorchVersion); // TODO make easier to understand if (inputTensors == null) { @@ -400,10 +391,13 @@ public HashMap call() { // while executing the task if (rp != null) rp.allowStopping(false); - Predictor predictor = model.newPredictor(); - NDList outputTensors = predictor.predict(inputTensors); + List> outputTensorList = new ArrayList>(); + for (DijTensor outTensor : params.outputList) + outputTensorList.add(Tensor.buildEmptyTensor(outTensor.name, outTensor.form)); + + model.runModel(inputTensors, outputTensorList); // Close inputTensors to avoid memory leak - inputTensors.close(); + inputTensors.stream().forEach(tt -> tt.close()); if (rp != null) rp.allowStopping(true); // Check if the user has tried to stop the execution while loading the model @@ -416,32 +410,31 @@ public HashMap call() { for (DijTensor outTensor : params.outputList) { if (log.getLevel() >= 1) log.print("Session run " + (c+1) + "/" + params.outputList.size()); - NDArray result = outputTensors.get(c); + Tensor result = (Tensor) outputTensorList.get(c); if (outTensor.tensorType.contains("image") && !params.pyramidalNetwork) { - impatch[imCounter] = ImagePlus2Tensor.NDArray2ImagePlus(result, outTensor.form, outTensor.name, params.pytorchVersion); + impatch[imCounter] = ImagePlus2Tensor.tensor2ImagePlus(result.getData(), outTensor.form); imCounter ++; c ++; } else if (outTensor.tensorType.contains("image") && (params.pyramidalNetwork || !params.allowPatching)) { - outputImages[imCounter] = ImagePlus2Tensor.NDArray2ImagePlus(result, outTensor.form, outTensor.name, params.pytorchVersion); + outputImages[imCounter] = ImagePlus2Tensor.tensor2ImagePlus(result.getData(), outTensor.form); outputImages[imCounter].setTitle(outputTitles[imCounter]); outputImages[imCounter].show(); imCounter ++; c ++; } else if (outTensor.tensorType.contains("list")){ - ResultsTable table = Table2Tensor.tensorToTable(result, outTensor.form, outTensor.name, params.pytorchVersion); + ResultsTable table = Table2Tensor.tensorToTable(result); outputTables.add(table); table.show(outputTitles[c ++]); } // Check if the user has tried to stop the execution while loading the model // If they have return false and stop if (rp != null && rp.isStopped()) { - outputTensors.close(); - manager.close(); + model.closeModel(); + outputTensorList.stream().forEach(tt -> tt.close()); return null; } } - outputTensors.close(); - manager.close(); + outputTensorList.stream().forEach(tt -> tt.close()); } catch (IncorrectNumberOfDimensions ex) { ex.printStackTrace(); @@ -463,14 +456,6 @@ public HashMap call() { IJ.log("\n"); commentAboutPytorchVersions(); return null; - } catch (EngineException ex) { - ex.printStackTrace(); - error = dimensionsMismatch(ex.getMessage()); - IJ.log("Error applying the model"); - IJ.log("Check that the specifications for the input are compatible with the model architecture."); - IJ.log(error); - commentAboutPytorchVersions(); - return null; } catch (Exception ex) { ex.printStackTrace(); error = dimensionsMismatch(ex.getMessage()); @@ -641,35 +626,16 @@ private static Object getTensorFromMap(HashMap inputMap, DijTens return inputMap.get(tensor.name); } - private static NDList getInputTensors(NDManager manager, NDList tensorsArray, List inputTensors, HashMap paramsMap, + private static < T extends RealType< T > & NativeType< T > > List> getInputTensors(List inputTensors, HashMap paramsMap, ImagePlus im, String pytorchVersion){ - tensorsArray = new NDList(); + List> tensorsArray = new ArrayList>(); for (DijTensor tensor : inputTensors) { - if (tensor.tensorType.contains("parameter") && (paramsMap.get(tensor.name) instanceof NDArray)) { - NDArray tt = (NDArray) paramsMap.get(tensor.name); + if (tensor.tensorType.contains("parameter") && (paramsMap.get(tensor.name) instanceof Tensor)) { + Tensor tt = (Tensor) paramsMap.get(tensor.name); tensorsArray.add(tt); - } else if (tensor.tensorType.contains("parameter") && (paramsMap.get(tensor.name) instanceof Tensor)) { - Tensor t = (Tensor) paramsMap.get(tensor.name); - try { - final float[] out = new float[t.numElements()]; - FloatBuffer outBuff = FloatBuffer.wrap(out); - t.writeTo(outBuff); - NDArray tt = manager.create(out, new Shape(t.shape())); - tensorsArray.add(tt); - } catch (Exception ex) { - tensorsArray.close(); - ex.printStackTrace(); - return null; - } } else if (tensor.tensorType.contains("image")) { - try { - NDArray tt = ImagePlus2Tensor.imPlus2tensor(manager, im, tensor.form, pytorchVersion); - tensorsArray.add(tt); - } catch (Exception ex) { - tensorsArray.close(); - ex.printStackTrace(); - return null; - } + RandomAccessibleInterval tt = ImagePlus2Tensor.imPlus2tensor(im, tensor.form); + tensorsArray.add(Tensor.build(tensor.name, tensor.form, tt)); } } return tensorsArray; diff --git a/src/main/java/deepimagej/RunnerProgress.java b/src/main/java/deepimagej/RunnerProgress.java index 2303bb6f..705dbd87 100755 --- a/src/main/java/deepimagej/RunnerProgress.java +++ b/src/main/java/deepimagej/RunnerProgress.java @@ -257,10 +257,8 @@ public void info() { } processor.setText("Model Inference (GPU: " + gpuTag + ")"); - if (runner != null && (runner instanceof RunnerTf)) - patches.setText("Patches: " + ((RunnerTf) runner).getCurrentPatch() + "/" + ((RunnerTf) runner).getTotalPatch()); - if (runner != null && (runner instanceof RunnerPt)) - patches.setText("Patches: " + ((RunnerPt) runner).getCurrentPatch() + "/" + ((RunnerPt) runner).getTotalPatch()); + if (runner != null && (runner instanceof RunnerDL)) + patches.setText("Patches: " + ((RunnerDL) runner).getCurrentPatch() + "/" + ((RunnerDL) runner).getTotalPatch()); } diff --git a/src/main/java/deepimagej/RunnerTf.java b/src/main/java/deepimagej/RunnerTf.java deleted file mode 100755 index 29f79ab1..00000000 --- a/src/main/java/deepimagej/RunnerTf.java +++ /dev/null @@ -1,772 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej; - -import java.nio.FloatBuffer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.concurrent.Callable; - -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; - -import ai.djl.ndarray.NDArray; -import deepimagej.exceptions.BatchSizeBiggerThanOne; -import deepimagej.tools.ArrayOperations; -import deepimagej.tools.CompactMirroring; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Index; -import deepimagej.tools.Log; -import deepimagej.tools.NumFormat; -import ij.IJ; -import ij.ImagePlus; -import ij.measure.ResultsTable; - -public class RunnerTf implements Callable> { - - private HashMap inputMap; - private DeepImageJ dp; - private RunnerProgress rp; - private Log log; - private int currentPatch = 0; - private int totalPatch = 0; - public String error = ""; - - public RunnerTf(DeepImageJ dp, RunnerProgress rp,HashMap inputMap, Log log) { - this.dp = dp; - this.rp = rp; - this.log = log; - this.inputMap = inputMap; - log.print("constructor runner"); - } - - @Override - public HashMap call() { - - if (rp != null) - rp.setInfoTag("applyModel"); - if (rp != null && log.getLevel() >= 1) { - log.print("call runner"); - rp.setVisible(true); - } - - - Parameters params = dp.params; - // Load the model first - SavedModelBundle model = dp.getTfModel(); - - String sigeDefTag = params.developer ? params.graph : DeepLearningModel.returnStringSig(params.graph); - SignatureDef sig = DeepLearningModel.getSignatureFromGraph(model, DeepLearningModel.returnStringSig(sigeDefTag)); - - if (log.getLevel() >= 1) { - log.print("model " + (model == null)); - log.print("sig " + (sig == null)); - } - - if (!params.developer) { - String[] inputs = DeepLearningModel.returnTfInputs(sig); - for (int i = 0; i < inputs.length; i ++) { - if (DijTensor.retrieveByName(inputs[i], params.inputList) == null) { - DijTensor inp = new DijTensor(inputs[i]); - inp.tensorType = "parameter"; - inp.setInDimensions(DeepLearningModel.modelTfEntryDimensions(sig, inputs[i])); - params.inputList.add(inp); - } - } - } - // Map that contains the input tensors that are not images. - // TODO restrict patching (or not) if the input contains parameters - HashMap parameterMap = new HashMap(); - ImagePlus imp = null; - // Auxiliary array with the same number of images as output tensors - int c = 0; - int inputImageInd = 0; - for (DijTensor tensor : params.inputList) { - if (tensor.tensorType.contains("image")) { - imp = getImageFromMap(inputMap, tensor); - if (imp == null) { - // TODO maybe we should allow running models without images - error = "No image provided."; - return null; - } - String inputPixelSizeX = ((float) imp.getCalibration().pixelWidth) + " " + imp.getCalibration().getUnit(); - String inputPixelSizeY = ((float) imp.getCalibration().pixelHeight) + " " + imp.getCalibration().getUnit(); - String inputPixelSizeZ = ((float) imp.getCalibration().pixelDepth) + " " + imp.getCalibration().getUnit(); - int[] dims = imp.getDimensions(); - params.inputList.get(c).inputTestSize = Integer.toString(dims[0]) + " x " + Integer.toString(dims[1]) + " x " + Integer.toString(dims[2]) + " x " + Integer.toString(dims[3]);; - params.inputList.get(c).inputPixelSizeX = inputPixelSizeX; - params.inputList.get(c).inputPixelSizeY = inputPixelSizeY; - params.inputList.get(c).inputPixelSizeZ = inputPixelSizeZ; - inputImageInd = c; - } else if (tensor.tensorType.contains("parameter")){ - Object tensorVal = getTensorFromMap(inputMap, tensor); - if (tensorVal == null) { - error = "The input tensor '" + tensor.name + "' should be given by" - + "the preprocessing but it is not."; - IJ.error(error); - return null; - } else if (tensorVal instanceof Tensor) { - parameterMap.put(tensor.name, (Tensor) tensorVal); - } else if (tensorVal instanceof NDArray) { - parameterMap.put(tensor.name, (NDArray) tensorVal); - } else { - // TODO improve error message and review what can be a preprocessing output - error = "Output of the preprocessing should be either a Tensor object" - + " or a NDArray object"; - IJ.error(error); - return null; - } - } - c ++; - } - - int outputImagesCount = 0; - for (DijTensor tensor : params.outputList) { - if (tensor.tensorType.contains("image")) - outputImagesCount ++; - } - ImagePlus[] outputImages = new ImagePlus[outputImagesCount]; - List outputTables = new ArrayList(); - - if (imp == null) { - // TODO maybe we should allow running models without images - error = "No image provided."; - return null; - } - int nx = imp.getWidth(); - int ny = imp.getHeight(); - int nc = imp.getNChannels(); - int nz = imp.getNSlices(); - - if (log.getLevel() >= 1) - log.print("image size " + nx + "x" + ny + "x" + nz); - - int[] indices = new int[4]; - String[] dimLetters = "XYCZ".split(""); - for (int i = 0; i < dimLetters.length; i ++) - indices[i] = Index.indexOf(params.inputList.get(inputImageInd).form.split(""), dimLetters[i]); - - int[] patchSize = {1, 1, 1, 1}; - int[] step = {1, 1, 1, 1}; - int[] minSize = {1, 1, 1, 1}; - for (int i = 0; i < indices.length; i ++) { - if (indices[i] != -1) { - patchSize[i] = params.inputList.get(inputImageInd).recommended_patch[indices[i]]; - step[i] = params.inputList.get(inputImageInd).step[indices[i]]; - minSize[i] = params.inputList.get(inputImageInd).minimum_size[indices[i]]; - } - } - - // TODO improve - if (params.pyramidalNetwork || !params.allowPatching) { - for (c = 0; c < patchSize.length; c ++) { - if (step[c] != 0 && patchSize[c] != imp.getDimensions()[c]) { - patchSize[c] = (int) Math.ceil((double) (imp.getDimensions()[c] - minSize[c]) / step[c]) * step[c] + minSize[c]; - } else if (patchSize[c] < imp.getDimensions()[c] && step[c] == 0) { - String errorMsg = "This model only accepts images with input size smaller or equal to:"; - for (int i = 0; i < dimLetters.length; i ++) { - errorMsg += "\n" + dimLetters[i] + " : " + patchSize[i]; - } - IJ.error(errorMsg); - return null; - } - } - } - - int px = patchSize[0]; int py = patchSize[1]; int pc = patchSize[2]; int pz = patchSize[3]; - - if (!ArrayOperations.isImageSizeAcceptable(new int[] {nx, ny, nc, nz}, patchSize, params.inputList.get(inputImageInd).form)) { - if (rp != null) - rp.stop(); - return null; - } - - if (log.getLevel() >= 1) - log.print("patch size " + "X: " + px + ", Y: " + py + ", Z: " + pz + ", C: " + pc); - - // To define the runtime for config.xml. Starting time - long startingTime = System.nanoTime(); - // Create the image that is going to be fed to the graph - ImagePlus[] impatch = new ImagePlus[outputImages.length]; - - String[] outputTitles = new String[params.outputList.size()]; - // Reset the counter to 0 use it again - c = 0; - int extensionInd = imp.getTitle().lastIndexOf('.'); - String imName = extensionInd == -1 ? imp.getTitle() : imp.getTitle().substring(0, extensionInd); - for (DijTensor outName: params.outputList) - outputTitles[c++] = dp.getName() + "_" + outName.name + "_" + imName; - - // Order of the dimensions. For example "NHWC"-->Batch size, Height, Width, Channels - String inputForm = params.inputList.get(inputImageInd).form; - int[] inputDims = params.inputList.get(inputImageInd).tensor_shape; - int channelPos = Index.indexOf(inputForm.split(""), "C"); - int[] inDim = imp.getDimensions(); - if (inDim[2] != inputDims[channelPos] && inputDims[channelPos] != -1) { - error = "The number of channels of the input image is incorrect.\n" - + "The models requires " + inputDims[channelPos] + "channels " - + "but the input image provided has " + inDim[2]; - IJ.error(error); - return null; - } - // Get the padding in case the image needs any - int[] padding = new int[4]; - if (!params.pyramidalNetwork) { - padding = findTotalPadding(params.outputList); - } - int roiX = px - padding[0] * 2; - int roiY = py - padding[1] * 2; - int roiZ = pz - padding[3] * 2; - int roiC = pc - padding[2] * 2; - int npx = (int) Math.ceil((double)nx / (double)roiX); - int npy = (int) Math.ceil((double)ny / (double)roiY); - int npc = (int) Math.ceil((double)nc / (double)roiC); - int npz = (int) Math.ceil((double)nz / (double)roiZ); - if (!params.allowPatching) { - npx = 1; npy = 1; npz = 1; npc = 1; - } - currentPatch = 0; - totalPatch = npx * npy * npz * npc; - - int[] roi = {roiX, roiY, roiC, roiZ}; - int[] size = {nx, ny, nc, nz}; - int[][] mirrorPixels = ArrayOperations.findAddedPixels(size, padding, roi); - ImagePlus mirrorImage = CompactMirroring.mirrorXY(imp, mirrorPixels[0][0], mirrorPixels[1][0], - mirrorPixels[0][1], mirrorPixels[1][1], - mirrorPixels[0][3], mirrorPixels[1][3]); - if (log.getLevel() == 2) { - mirrorImage.setTitle("Extended image"); - mirrorImage.getProcessor().resetMinAndMax(); - mirrorImage.show(); - } - - // If the roi of the patch is bigger than the actual image wanted, consider all the - // remaining pixels as overlap (padding). Consider that now there might be then different - // padding for X and Y - int overlapX = mirrorPixels[0][0]; - if (roiX > nx) { - roiX = nx; - padding[0] = (px - nx) / 2; - overlapX = (px - nx) / 2; - } - - int overlapY = mirrorPixels[0][1]; - if (roiY > ny) { - roiY = ny; - padding[1] = (py - ny) / 2; - overlapY = (py - ny) / 2; - } - - int overlapZ = mirrorPixels[0][3]; - if (roiZ > nz) { - roiZ = nz; - padding[3] = (pz - nz) / 2; - overlapZ = (pz - nz) / 2; - } - - if (log.getLevel() >= 1) - log.print("start " + npx + "x" + npy); - - for (int i = 0; i < npx; i++) { - for (int j = 0; j < npy; j++) { - for (int z = 0; z < npz; z++) { - // TODO reduce this mega big loop to something more modular - currentPatch++; - System.out.println("[DEBUG] (Tensorflow) Patch " + currentPatch + "/" + totalPatch); - if (log.getLevel() >= 1) - log.print("currentPatch " + currentPatch); - if (rp != null && rp.isStopped()) { - rp.stop(); - return null; - } - // Variables to track when the roi starts in the mirror image - int xMirrorStartPatch; - int yMirrorStartPatch; - int zMirrorStartPatch; - - // Variables to track when the roi starts in the patch - int xImageStartPatch; - int xImageEndPatch; - int yImageStartPatch; - int yImageEndPatch; - int zImageStartPatch; - int zImageEndPatch; - int leftoverPixelsX; - int leftoverPixelsY; - int leftoverPixelsZ; - if (i < npx -1 || npx == 1) { - xMirrorStartPatch = padding[0] + roiX*i; - - xImageStartPatch = roiX*i; - xImageEndPatch = roiX*(i + 1); - leftoverPixelsX = overlapX; - } else { - xMirrorStartPatch = nx + padding[0] - roiX; - - xImageStartPatch = roiX*i; - xImageEndPatch = nx; - leftoverPixelsX = overlapX + roiX - (xImageEndPatch - xImageStartPatch); - } - - if (j < npy - 1 || npy == 1) { - yMirrorStartPatch = padding[1] + roiY*j; - - yImageStartPatch = roiY*j; - yImageEndPatch = roiY*(j + 1); - leftoverPixelsY = overlapY; - } else { - yMirrorStartPatch = ny + padding[1] - roiY; - - yImageStartPatch = roiY*j; - yImageEndPatch = ny; - leftoverPixelsY = overlapY + roiY - (yImageEndPatch - yImageStartPatch); - } - - if (z < npz - 1 || npz == 1) { - zMirrorStartPatch = padding[3] + roiZ*z; - - zImageStartPatch = roiZ*z; - zImageEndPatch = roiZ*(z + 1); - leftoverPixelsZ = overlapZ; - } else { - zMirrorStartPatch = nz + padding[3] - roiZ; - - zImageStartPatch = roiZ*z; - zImageEndPatch = nz; - leftoverPixelsZ = overlapZ + roiZ- (zImageEndPatch - zImageStartPatch); - } - - // TODO mirar en profundidad. Que pasa cuando el mirror no es igual de grande que le patch - // Observe que se compensaba erroneamente - ImagePlus patch = ArrayOperations.extractPatch(mirrorImage, patchSize, xMirrorStartPatch, yMirrorStartPatch, - zMirrorStartPatch, overlapX, overlapY, overlapZ); - if (log.getLevel() >= 1) - log.print("Extract Patch (" + (i + 1) + ", " + (j + 1) + ") patch size: " + patch.getWidth() + "x" + patch.getHeight() + " pixels"); - if (log.getLevel() == 2) { - patch.setTitle("Patch (" + i + "," + j + ")"); - patch.getProcessor().resetMinAndMax(); - } - - Tensor[] inputTensors = getInputTensors(params.inputList, parameterMap, patch, pc); - Session.Runner sess = model.session().runner(); - - for (int k = 0; k < params.inputList.size(); k++) { - // The thread cannot be stopped while loading a model, thus block the button - // while executing the task - if (rp != null ) - rp.allowStopping(false); - sess = sess.feed(opName(sig.getInputsOrThrow(params.inputList.get(k).name)), inputTensors[k]); - if (rp != null ) - rp.allowStopping(true); - // Check if the user has tried to stop the execution while loading the model - // If they have return false and stop - if (rp != null && rp.isStopped()) - return null; - } - // Reinitialise the counter - c = 1; - for (DijTensor outTensor : params.outputList) { - // The thread cannot be stopped while loading a model, thus block the button - // while executing the task - if (rp != null ) - rp.allowStopping(false); - sess = sess.fetch(opName(sig.getOutputsOrThrow(outTensor.name))); - if (log.getLevel() >= 1) - log.print("Session fetch " + (c ++)); - if (rp != null ) - rp.allowStopping(true); - // Check if the user has tried to stop the execution while loading the model - // If they have return false and stop - if(rp != null && rp.isStopped()) - return null; - } - try { - // The thread cannot be stopped while loading a model, thus block the button - // while executing the task - if (rp != null ) - rp.allowStopping(false); - List> fetches = sess.run(); - if (rp != null ) - rp.allowStopping(true); - // Check if the user has tried to stop the execution while loading the model - // If they have return false and stop - if (rp != null && rp.isStopped()) - return null; - // Reinitialise counter - c = 0; - int imCounter = 0; - for (DijTensor outTensor : params.outputList) { - if (log.getLevel() >= 1) - log.print("Session run " + (c+1) + "/" + params.outputList.size()); - Tensor result = fetches.get(c); - if (outTensor.tensorType.contains("image") && !params.pyramidalNetwork && params.allowPatching) { - impatch[imCounter] = ImagePlus2Tensor.tensor2ImagePlus(result, outTensor.form, outTensor.name); - imCounter ++; - c ++; - } else if (outTensor.tensorType.contains("image") && (params.pyramidalNetwork || !params.allowPatching)) { - outputImages[imCounter] = ImagePlus2Tensor.tensor2ImagePlus(result, outTensor.form, outTensor.name); - outputImages[imCounter].setTitle(outputTitles[c ++]); - outputImages[imCounter].show(); - imCounter ++; - } else if (outTensor.tensorType.contains("list")){ - ResultsTable table = Table2Tensor.tensorToTable(result, outTensor.form, outTensor.name); - outputTables.add(table); - table.show(outputTitles[c ++]); - } - result.close(); - // TODO put in a method - // Check if the user has tried to stop the execution while loading the model - // If they have return false and stop - if (rp != null && rp.isStopped()) { - // Close every tensor and stop - // Close input tensors - for (int ii = 0; ii < inputTensors.length; ii ++) - inputTensors[ii].close(); - for (Tensor oo : fetches) - oo.close(); - return null; - } - - } - // Close input tensors - for (int ii = 0; ii < inputTensors.length; ii ++) { - inputTensors[ii].close(); - } - } - catch(IllegalArgumentException ex) { - ex.printStackTrace(); - error = "Incorrect input dimensions"; - IJ.log("Error applying the model"); - IJ.log("The dimensions of the input are incorrect."); - IJ.log("The model might require only specific input sizes."); - IJ.log("Another of the possible options is that the model has an encoder decoder\n" - + "architecture that requires input to be divisible a certain amount of times."); - IJ.log("Please review the model architecture and the step and patch parameters."); - return null; - } catch(BatchSizeBiggerThanOne ex) { - ex.printStackTrace(); - error = "Output batch size bigger than 1 for tensor '" + ex.getName() + "'.\n Batch_size > 1 not supported by this version of DeepImageJ"; - IJ.log("Error applying the model"); - IJ.log(error); - IJ.log(ex.toString()); - return null; - } catch(IllegalStateException ex) { - ex.printStackTrace(); - error = "Missing weights"; - IJ.log("Error applying the model"); - IJ.log("Uninitialized weights."); - IJ.log("Check that the variables/weights folder contains a correct version of the weights"); - return null; - } - catch (Exception ex) { - // TODO MAKE THIS EXCEPTION MORE ESPECIFIC - ex.printStackTrace(); - IJ.log("Error applying the model"); - return null; - } - float[][] allOffsets = findOutputOffset(params.outputList); - int imCounter = 0; - for (int counter = 0; counter < params.outputList.size(); counter++) { - if (params.outputList.get(counter).tensorType.contains("image") && !params.pyramidalNetwork && params.allowPatching) { - float[] outSize = findOutputSize(size, params.outputList.get(counter), params.inputList, impatch[imCounter].getDimensions()); - if (outputImages[imCounter] == null) { - int[] dims = impatch[imCounter].getDimensions(); - outputImages[imCounter] = IJ.createHyperStack(outputTitles[counter], (int)outSize[0], (int)outSize[1], (int)outSize[2], (int)outSize[3], dims[4], 32); - outputImages[imCounter].getProcessor().resetMinAndMax(); - outputImages[imCounter].show(); - } - float scaleX = outSize[0] / nx; float scaleY = outSize[1] / ny; float scaleZ = outSize[3] / nz; - ArrayOperations.imagePlusReconstructor(outputImages[imCounter], impatch[imCounter], (int) (xImageStartPatch * scaleX), - (int) (xImageEndPatch * scaleX), (int) (yImageStartPatch * scaleY), (int) (yImageEndPatch * scaleY), - (int) (zImageStartPatch * scaleZ), (int) (zImageEndPatch * scaleZ),(int)(leftoverPixelsX * scaleX + Math.ceil(allOffsets[imCounter][0])), - (int)(leftoverPixelsY * scaleY + Math.ceil(allOffsets[imCounter][1])), (int)(leftoverPixelsZ * scaleZ + Math.ceil(allOffsets[imCounter][3]))); - if (outputImages[imCounter] != null) - outputImages[imCounter].getProcessor().resetMinAndMax(); - if (rp != null && rp.isStopped()) { - rp.stop(); - return null; - } - imCounter ++; - } else if (params.outputList.get(counter).tensorType.contains("image") && params.pyramidalNetwork) { - // TODO improve - int[] outPatchDims = outputImages[imCounter].getDimensions(); - String[] ijForm = "XYCZB".split(""); - String dijForm = params.outputList.get(counter).form; - int[] pyramidOut = params.outputList.get(counter).sizeOutputPyramid; - for (int dd = 0; dd < ijForm.length; dd ++) { - int idx = dijForm.indexOf(ijForm[dd]); - if (idx == -1 && outPatchDims[dd] == 1) { - continue; - } else if (idx != -1 && outPatchDims[dd] == pyramidOut[idx]) { - continue; - } - IJ.error("The dimensions of the output image do not coincide\n" - + "with the dimensions specified previously:\n" - + "Specified output dimensions: dimension order -> " + dijForm + ", dimension size -> " + Arrays.toString(pyramidOut) - + "Actual output dimensions: dimension order -> XYCZB, dimension size -> " + Arrays.toString(outPatchDims)); - error = "Error specifying output dimensions."; - return null; - } - if (rp != null && rp.isStopped()) { - rp.stop(); - return null; - } - imCounter ++; - } else if (params.outputList.get(counter).tensorType.contains("image") && !params.pyramidalNetwork && !params.allowPatching) { - // TODO improve - int[] outPatchDims = outputImages[imCounter].getDimensions(); - String[] ijForm = "XYCZB".split(""); - String dijForm = params.outputList.get(counter).form; - float[] scale = params.outputList.get(counter).scale; - float[] offset = params.outputList.get(counter).offset; - // TODO adapt for more inputs - // We take the mirrored image as the reference, because that is what ends - // up going into the model - int[] refSize = mirrorImage.getDimensions(); - String thSizeStr = "["; - for (int dd = 0; dd < ijForm.length; dd ++) { - int idx = dijForm.indexOf(ijForm[dd]); - if (idx == -1 && outPatchDims[dd] == scale[idx]) { - thSizeStr += scale[idx] + ","; - continue; - } else if (idx != -1 && outPatchDims[dd] == (int)(refSize[dd] * scale[idx]) + 2 * offset[idx]) { - thSizeStr += ((int)(refSize[dd] * scale[idx]) + 2 * offset[idx]) + ","; - continue; - } - for (dd ++; dd < ijForm.length;) { - idx = dijForm.indexOf(ijForm[dd]); - if (idx == -1) { - thSizeStr += scale[idx] + ","; - } else if (idx != -1) { - thSizeStr += ((int)(refSize[dd] * scale[idx]) + 2 * offset[idx]) + ","; - } - } - thSizeStr = thSizeStr.substring(0, thSizeStr.length() - 1) + "]"; - IJ.error("The dimensions of the output image do not coincide\n" - + "with the dimensions specified previously:\n" - + "Specified output dimensions: dimension order -> XYCZB, dimension size -> " + thSizeStr - + "Actual output dimensions: dimension order -> XYCZB, dimension size -> " + Arrays.toString(outPatchDims)); - error = "Error specifying output dimensions."; - return null; - } - if (rp != null && rp.isStopped()) { - rp.stop(); - return null; - } - imCounter ++; - } - } - if (log.getLevel() >= 1) - log.print("Create Output "); - } - } - } - - // To define the runtime. End time - long endTime = System.nanoTime(); - params.runtime = NumFormat.seconds(endTime - startingTime); - // Set Parameter params.memoryPeak - if (rp != null ) - params.memoryPeak = NumFormat.bytes(rp.getPeakmem()); - // Set Parameter params.outputSize - HashMap outputMap = new HashMap(); - int imageCount = 0; - int tableCount = 0; - c = 0; - for (DijTensor tensor : params.outputList) { - if (tensor.tensorType.contains("image")) { - ImagePlus im = outputImages[imageCount]; - im.setPosition(1, 1, 1); - im.getProcessor().resetMinAndMax(); - // Add the image to the output map - outputMap.put(tensor.name, im); - } else if (tensor.tensorType.contains("list")) { - // Add the results table to the output map - outputMap.put(tensor.name, outputTables.get(tableCount ++)); - } - } - - - return outputMap; - } - - private static ImagePlus getImageFromMap(HashMap inputMap, DijTensor tensor) { - if (!inputMap.containsKey(tensor.name)){ - IJ.error("Preprocessing should provide a HashMap with\n" - + "the key " + tensor.name); - return null; - } else if (!(inputMap.get(tensor.name) instanceof ImagePlus)) { - IJ.error("The input " + tensor.name + " should" - + " be an instance of an ImagePlus."); - return null; - } - ImagePlus imp = (ImagePlus) inputMap.get(tensor.name); - return imp; - } - - private static Object getTensorFromMap(HashMap inputMap, DijTensor tensor){ - if (!inputMap.containsKey(tensor.name)){ - IJ.error("Preprocessing should provide a HashMap with\n" - + "the key " + tensor.name); - return null; - } else if (!(inputMap.get(tensor.name) instanceof Tensor)) { - IJ.error("The input " + tensor.name + " should" - + " be an instance of a Tensor."); - return null; - } - return inputMap.get(tensor.name); - } - - private static Tensor[] getInputTensors(List inputTensors, HashMap paramsMap, - ImagePlus im, int pc){ - Tensor[] tensorsArray = new Tensor[inputTensors.size()]; - int c = 0; - for (DijTensor tensor : inputTensors) { - if (tensor.tensorType.contains("parameter") && paramsMap.get(tensor.name) instanceof Tensor) { - tensorsArray[c ++] = (Tensor) paramsMap.get(tensor.name); - } else if (tensor.tensorType.contains("parameter") && paramsMap.get(tensor.name) instanceof NDArray) { - NDArray t = (NDArray) paramsMap.get(tensor.name); - final float[] out = t.toFloatArray(); - FloatBuffer outBuff = FloatBuffer.wrap(out); - tensorsArray[c ++] = Tensor.create(t.getShape().getShape(), outBuff); - } else { - tensorsArray[c ++] = ImagePlus2Tensor.implus2TensorFloat(im, tensor.form); - } - } - return tensorsArray; - } - - private static float[] findOutputSize(int[] inpSize, DijTensor outTensor, List inputList, int[] patchSize) { - String refForOutput = outTensor.referenceImage; - DijTensor refTensor = DijTensor.retrieveByName(refForOutput, inputList); - float[] outSize = new float[inpSize.length]; - String[] standarForm = "XYCZ".split(""); - for (int i = 0; i < outSize.length; i ++) { - int indOut = Index.indexOf(outTensor.form.split(""), standarForm[i]); - int indInp = -1; - if (refTensor != null) - indInp = Index.indexOf(refTensor.form.split(""), standarForm[i]); - if (indOut != -1 && indInp != -1) { - if (standarForm[i].toLowerCase().equals("c")) - outSize[i] = inpSize[i] * outTensor.scale[indOut] + 2*outTensor.offset[indOut]; - else - outSize[i] = inpSize[i] * outTensor.scale[indOut]; - } else if (indOut != -1 && indInp == -1) { - outSize[i] = patchSize[i]; - } else { - outSize[i] = 1; - } - } - return outSize; - } - - private String opName(final TensorInfo t) { - final String n = t.getName(); - if (n.endsWith(":0")) { - return n.substring(0, n.lastIndexOf(":0")); - } - return n; - } - - public static int[] findTotalPadding(List outputs) { - // Create an object of int[] that contains the output dimensions - // of each patch. - // This dimensions are always of the form [x, y, c, d] - int[] padding = {0, 0, 0, 0}; - String[] form = "XYCZ".split(""); - for (DijTensor out: outputs) { - if (!out.tensorType.equals("image")) - continue; - for (int i = 0; i < form.length; i ++) { - int ind = Index.indexOf(out.form.split(""), form[i]); - if (out.tensorType.contains("image") && ind != -1 && !form[i].equals("B") && !form[i].equals("C")) { - double totalPad = Math.ceil(-1 * (double)out.offset[ind] / (double)out.scale[ind]) + Math.ceil((double)out.halo[ind] / (double)out.scale[ind]); - if ((int) totalPad > padding[i]) { - padding[i] = (int) totalPad; - } - } - } - } - return padding; - } - - // TODO clean up method (line 559) Make it stable for pyramidal - public static float[][] findOutputOffset(List outputs) { - // Create an object of int[] that contains the output dimensions - // of each patch. - // This dimensions are always of the form [x, y, c, d] - float[][] offsets = new float[outputs.size()][4]; - String[] form = "XYCZ".split(""); - int c1 = 0; - for (DijTensor out: outputs) { - if (!out.tensorType.toLowerCase().equals("image")) - continue; - int c2 = 0; - for (int i = 0; i < offsets[0].length; i ++) { - int ind = Index.indexOf(out.form.split(""), form[i]); - if (ind != -1 && out.offset != null) { - offsets[c1][c2] = out.offset[ind]; - } - c2 ++; - } - c1 ++; - } - return offsets; - } - - public int getCurrentPatch() { - return currentPatch; - } - - public int getTotalPatch() { - return totalPatch; - } - -} diff --git a/src/main/java/deepimagej/Table2Tensor.java b/src/main/java/deepimagej/Table2Tensor.java index 728e6312..15eca40e 100755 --- a/src/main/java/deepimagej/Table2Tensor.java +++ b/src/main/java/deepimagej/Table2Tensor.java @@ -44,20 +44,24 @@ package deepimagej; -import java.nio.FloatBuffer; import java.util.ArrayList; import java.util.Arrays; -import org.tensorflow.Tensor; - -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.Shape; import deepimagej.exceptions.BatchSizeBiggerThanOne; import deepimagej.exceptions.IncorrectNumberOfDimensions; import deepimagej.tools.Index; import ij.IJ; import ij.measure.ResultsTable; +import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.utils.IndexingUtils; +import net.imglib2.Cursor; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.img.array.ArrayImgFactory; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.view.IntervalView; public class Table2Tensor { @@ -117,21 +121,31 @@ public static ResultsTable flatArrayToTable(float[] flatArray, long[] shape, Str /* * Convert NDArrays into results table */ - public static ResultsTable tensorToTable(Tensor tensor, String form, String name) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne { + public static < T extends RealType< T > & NativeType< T > > ResultsTable + tensorToTable(Tensor tensor) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne { + return imgLib2ToTable(tensor.getData(), tensor.getAxesOrderString(), tensor.getName()); + } + + + /* + * Convert NDArrays into results table + */ + public static < T extends RealType< T > & NativeType< T > > ResultsTable + imgLib2ToTable(RandomAccessibleInterval data, String axes, String name) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne { - long[] shape = tensor.shape(); - if (form == null) - form = findTableForm(shape, name); + long[] shape = data.dimensionsAsLongArray(); + if (axes == null) + axes = findTableForm(shape, name); // If DeepImageJ has not been able to induce a form, the tensor is not valid - if (form == null) + if (axes == null) return null; // Check that the output dimensions correspond to the form length - if (shape.length != form.length()) - throw new IncorrectNumberOfDimensions(shape, form, name); + if (shape.length != axes.length()) + throw new IncorrectNumberOfDimensions(shape, axes, name); // TODO add possibility of batch>1 - int batchIndex = form.indexOf("B"); + int batchIndex = axes.indexOf("B"); if (batchIndex != -1 && shape[batchIndex] > 1) - throw new BatchSizeBiggerThanOne(shape, form, name); + throw new BatchSizeBiggerThanOne(shape, axes, name); // For the moment DeepImageJ only supports 2D tables, thus // if the tensor has more than to dimensions greater than one, @@ -143,60 +157,30 @@ public static ResultsTable tensorToTable(Tensor tensor, String form, String n + "Represent the tensor as an image, instead of as a list."); return null; } - // Array of one dimension containing all the data from the tensor int arraySize = 1; for (long el : shape) arraySize = arraySize * ((int) el); - float[] flatArray = new float[arraySize]; + float[] flatArray = new float[arraySize]; - FloatBuffer outBuff = FloatBuffer.wrap(flatArray); - tensor.writeTo(outBuff); - return flatArrayToTable(flatArray, shape, form); - } - - - /* - * Convert NDArrays into results table - */ - public static ResultsTable tensorToTable(NDArray tensor, String form, String name, String ptVersion) throws IncorrectNumberOfDimensions, BatchSizeBiggerThanOne { - // Array of one dimension containing all the data from the tensor - float[] flatArray = tensor.toFloatArray(); - long[] shape = tensor.getShape().getShape(); - if (form == null) - form = findTableForm(shape, name); - // If DeepImageJ has not been able to induce a form, the tensor is not valid - if (form == null) - return null; - int batchIndex = form.indexOf("B"); - // TODO should batch be eliminated always or only when the dimensions are incorrect - boolean old = ImagePlus2Tensor.olderThanPytorch170(ptVersion); - if (old && batchIndex != -1) { - String oldForm = "" + form; - form = oldForm.substring(0, batchIndex) + oldForm.substring(batchIndex + 1); - IJ.log("WARNING: DJL Pytorch versions <=1.6.0 do not allow definition of the batch size."); - IJ.log("WARNING: Output tensor " + name + " dimension organization has changed: " + oldForm + " --> " + form); - } - // REtrieve again the batch index - batchIndex = form.indexOf("B"); - // TODO add possibility of batch>1 - if (batchIndex != -1 && shape[batchIndex] > 1) - throw new BatchSizeBiggerThanOne(shape, form, name); - // Check that the output dimensions correspond to the form length - if (shape.length != form.length()) - throw new IncorrectNumberOfDimensions(shape, form, name); - - // For the moment DeepImageJ only supports 2D tables, thus - // if the tensor has more than to dimensions greater than one, - // the plugin throws an exception - // TODO support more than 2d tables - ArrayList non1Occurences = findNon1occurences(shape); - if (non1Occurences.size() > 2) { - IJ.error("For the moment DeepImageJ only supports 2D tables as outputs with batch_size = 1.\n" - + "Represent the tensor as an image, instead of as a list."); - return null; + Cursor tensorCursor; + if (data instanceof IntervalView) + tensorCursor = ((IntervalView) data).cursor(); + else if (data instanceof Img) + tensorCursor = ((Img) data).cursor(); + else throw new IllegalArgumentException("The data of the " + Tensor.class + + " has " + "to be an instance of " + Img.class + " or " + + IntervalView.class); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, + shape); + float val = tensorCursor.get().getRealFloat(); + flatArray[flatPos] = val; } - return flatArrayToTable(flatArray, shape, form); + + return flatArrayToTable(flatArray, shape, axes); } /* @@ -395,7 +379,7 @@ public static float[] tableToFlatArray(ResultsTable rt, String form, long[] arra * Method that gets a Tensor from a ResultsTable */ // TODO can this method be used somewhere in the plugin? - public static NDArray tableToTensor(ResultsTable rt, String form, String ptVersion, NDManager manager){ + public static Img tableToTensor(ResultsTable rt, String form){ // Get rows and columns (by default sum 1 to the number of columns) int rSize = rt.size(); // Get last column indicates position, that is why we sum 1 @@ -409,55 +393,25 @@ public static NDArray tableToTensor(ResultsTable rt, String form, String ptVersi IJ.error("Table has 2 dimensions but only one (" + form + ") was specified."); return null; } - - boolean old = ImagePlus2Tensor.olderThanPytorch170(ptVersion); - int batchIndex = form.indexOf("B"); - // TODO should batch be eliminated always or only when the dimensions are incorrect - if (old && batchIndex != -1) { - String oldForm = "" + form; - form = oldForm.substring(0, batchIndex) + oldForm.substring(batchIndex + 1); - IJ.log("WARNING: DJL Pytorch versions <=1.6.0 do not allow definition of the batch size."); - IJ.log("WARNING: List input tensor dimension organization has changed: " + oldForm + " --> " + form); - } - // Create a variable that acts as imp.getDimensions() for ImagePlus types int[] defaultTableDimensions = new int[] {rSize, cSize}; long[] arrayShape = getTableTensorDims(defaultTableDimensions, form); // Get the array - float[] flatRt = tableToFlatArray(rt, form, arrayShape); - // Create the tensor - FloatBuffer outBuff = FloatBuffer.wrap(flatRt); - NDArray tensor = manager.create(flatRt, new Shape(arrayShape)); - return tensor; - } - - /* - * Method that gets a Tensor from a ResultsTable - */ - // TODO can this method be used somewhere in the plugin? - public static Tensor tableToTensor(ResultsTable rt, String form){ - // Get rows and columns (by default sum 1 to the number of columns) - int rSize = rt.size(); - // Get last column indicates position, that is why we sum 1 - int cSize = rt.getLastColumn() + 1; - if (cSize == 0) - cSize = 1; - if (cSize != 1 && rSize != 1 && form.indexOf("B") != -1) { - IJ.error("Batch size should be 1."); - return null; - } else if (cSize != 1 && rSize != 1 && form.length() == 1) { - IJ.error("Table has 2 dimensions but only one (" + form + ") was specified."); - return null; + float[] flatRt = tableToFlatArray(rt, form, arrayShape); + + + final ArrayImgFactory factory = new ArrayImgFactory<>(new FloatType()); + final Img outputImg = factory.create(arrayShape); + Cursor tensorCursor = outputImg.cursor(); + while (tensorCursor.hasNext()) { + tensorCursor.fwd(); + long[] cursorPos = tensorCursor.positionAsLongArray(); + int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, + arrayShape); + float val = flatRt[flatPos]; + tensorCursor.get().set(val); } - // Create a variable that acts as imp.getDimensions() for ImagePlus types - int[] defaultTableDimensions = new int[] {rSize, cSize}; - long[] arrayShape = getTableTensorDims(defaultTableDimensions, form); - // Get the array - float[] flatRt = tableToFlatArray(rt, form, arrayShape); - // Create the tensor - FloatBuffer outBuff = FloatBuffer.wrap(flatRt); - Tensor tensor = Tensor.create(arrayShape, outBuff); - return tensor; + return outputImg; } /* diff --git a/src/main/java/deepimagej/modelrunner/EngineInstaller.java b/src/main/java/deepimagej/modelrunner/EngineInstaller.java new file mode 100644 index 00000000..2699f292 --- /dev/null +++ b/src/main/java/deepimagej/modelrunner/EngineInstaller.java @@ -0,0 +1,88 @@ +package deepimagej.modelrunner; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import io.bioimage.modelrunner.bioimageio.download.DownloadTracker; +import io.bioimage.modelrunner.bioimageio.download.DownloadTracker.TwoParameterConsumer; +import io.bioimage.modelrunner.utils.Log; + +public class EngineInstaller { + + private static String TOTAL_PROGRESS_STRING = ""; + private static String TOTAL_REMAINING_STRING = ""; + private static int N_BINS = 10; + + private HashMap timesMap = new HashMap(); + + static { + for (int i = 0; i < N_BINS; i ++) { + TOTAL_PROGRESS_STRING += "#"; + TOTAL_REMAINING_STRING += "."; + } + } + + /** + * Create a String that summarizes the information about the download of the + * engines specifies by the parameter 'basicEng' and the real time information + * about the download contained in the consumer + * @param consumers + * @return + */ + public String basicEnginesInstallationProgress( + Map> consumers) { + + String progress = ""; + for (Entry> entry : consumers.entrySet()) { + TwoParameterConsumer consumer = entry.getValue(); + double totalProgress = consumer.get().keySet().contains(DownloadTracker.TOTAL_PROGRESS_KEY) ? + consumer.get().get(DownloadTracker.TOTAL_PROGRESS_KEY) : 0.0; + if (totalProgress == 0.0) + continue; + if (!this.timesMap.keySet().contains(entry.getKey())) + timesMap.put(entry.getKey(), Log.gct()); + String timeKey = timesMap.get(entry.getKey()); + progress += System.lineSeparator(); + progress += " - " + timeKey + " -- " + new File(entry.getKey()).getName(); + + progress += " " + getProgressPerc(totalProgress) + System.lineSeparator(); + for (Entry fEntry : consumer.get().entrySet()) { + if (fEntry.getKey().equals(DownloadTracker.TOTAL_PROGRESS_KEY)) + continue; + if (!this.timesMap.keySet().contains(fEntry.getKey())) + timesMap.put(fEntry.getKey(), Log.gct()); + String timeKey2 = timesMap.get(fEntry.getKey()); + progress += " -- " + timeKey2 + " -- " + new File(fEntry.getKey()).getName(); + progress += " " + getProgressPerc(fEntry.getValue()) + System.lineSeparator(); + + } + } + if (!progress.equals("") || consumers.keySet().size() == 0) + return progress; + for (Entry> entry : consumers.entrySet()) { + if (!this.timesMap.keySet().contains(entry.getKey())) + timesMap.put(entry.getKey(), Log.gct()); + String timeKey = timesMap.get(entry.getKey()); + progress += System.lineSeparator(); + progress += " - " + timeKey + " -- Installing: " + new File(entry.getKey()).getName(); + progress += " " + getProgressPerc(0) + System.lineSeparator(); + break; + } + return progress; + } + + private static String getProgressPerc(double progress) { + String progressStr = "[" + Math.round(progress * 100) + "%]"; + return progressStr; + } + + private static String getProgressBar(double progress) { + int nProgressBar = (int) (progress * N_BINS); + String progressStr = "[" + TOTAL_PROGRESS_STRING.substring(0, nProgressBar) + + TOTAL_REMAINING_STRING.substring(nProgressBar) + "] " + Math.round(progress * 100) + "%"; + return progressStr; + } + +} diff --git a/src/main/java/deepimagej/stamp/AbstractStamp.java b/src/main/java/deepimagej/stamp/AbstractStamp.java deleted file mode 100755 index 5e13e6dd..00000000 --- a/src/main/java/deepimagej/stamp/AbstractStamp.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Dimension; - -import javax.swing.JPanel; - -import deepimagej.BuildDialog; -import ij.IJ; - -public abstract class AbstractStamp { - - protected BuildDialog parent; - protected JPanel panel; - - public AbstractStamp(BuildDialog parent) { - this.parent = parent; - panel = new JPanel(new BorderLayout()); - } - - public abstract void buildPanel(); - - public abstract void init(); - - public abstract boolean finish(); - - public JPanel getPanel() { - return panel; - } - - public void closePlugin() { - panel.removeAll(); - } - -} diff --git a/src/main/java/deepimagej/stamp/InformationStamp.java b/src/main/java/deepimagej/stamp/InformationStamp.java deleted file mode 100755 index c69bac3b..00000000 --- a/src/main/java/deepimagej/stamp/InformationStamp.java +++ /dev/null @@ -1,731 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Component; -import java.awt.Container; -import java.awt.Dimension; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.Insets; -import java.awt.TextField; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.net.URL; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Vector; - -import javax.swing.BorderFactory; -import javax.swing.DefaultListCellRenderer; -import javax.swing.DefaultListModel; -import javax.swing.JButton; -import javax.swing.JComponent; -import javax.swing.JFrame; -import javax.swing.JLabel; -import javax.swing.JList; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextArea; -import javax.swing.JTextField; -import javax.swing.ListSelectionModel; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.HTMLPane; -import ij.IJ; -import ij.gui.GenericDialog; - -public class InformationStamp extends AbstractStamp implements ActionListener { - - public JTextField txtName = new JTextField("", 24); - - public JTextField txtAuth = new JTextField("", 24); - public JTextField txtTag = new JTextField("", 24); - - public JTextField txtDocumentation = new JTextField("", 24); - public JTextField txtGitRepo = new JTextField("", 24); - public JTextField txtLicense = new JTextField("", 24); - public JTextArea txtDescription = new JTextArea("", 3, 24); - - public JList> authList = new JList>(); - public JList tagList = new JList(); - public JList> citationList = new JList>(); - - private DefaultListModel> authModel; - private DefaultListModel tagModel; - private DefaultListModel> citationModel; - - public JButton authAddBtn = new JButton("Add"); - public JButton authRmvBtn = new JButton("Remove"); - - public JButton tagAddBtn = new JButton("Add"); - public JButton tagRmvBtn = new JButton("Remove"); - - public JButton citationAddBtn = new JButton("Add"); - public JButton citationRmvBtn = new JButton("Remove"); - - public ArrayList> introducedAuth = new ArrayList>(); - public ArrayList introducedTag = new ArrayList(); - public ArrayList> introducedCitation = new ArrayList>(); - - // Key words for a class method to know whether to build - // a citation panel or an authorship panel - private String authTag = "auth"; - private String citeTag = "cite"; - - // Parameter to keep track of the model being used - public String model = ""; - - public InformationStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - HTMLPane pane = new HTMLPane(Constants.width, 80); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "General Information"); - pane.append("p", "This information will be stored in the config.yaml"); - pane.append("p", "Add the reference to properly cite the pretrained model."); - txtDescription.setBorder(BorderFactory.createLineBorder(Color.gray)); - - JFrame pnFr = new JFrame(); - Container pn = pnFr.getContentPane(); - pn.setLayout(new GridBagLayout()); - - GridBagConstraints labelC = new GridBagConstraints(); - labelC.gridwidth = 4; - labelC.gridheight = 1; - labelC.gridx = 0; - labelC.gridy = 0; - labelC.ipadx = 5; - labelC.weightx = 0.2; - - GridBagConstraints infoC = new GridBagConstraints(); - infoC.gridwidth = 20; - infoC.gridheight = 1; - infoC.gridx = 4; - infoC.gridy = 0; - infoC.ipadx = 5; - infoC.weightx = 0.8; - infoC.anchor = GridBagConstraints.CENTER; - infoC.fill = GridBagConstraints.BOTH; - infoC.insets = new Insets(10, 0, 10, 10); - - // MOdel name field - pn.add(new JLabel("Full name"), labelC); - pn.add(txtName, infoC); - - // Authorship field - labelC.gridy = 1; - labelC.ipadx = 0; - labelC.ipady = 0; - infoC.gridy = 1; - infoC.insets = new Insets(0, 0, 0, 0); - infoC.ipady = 50; - infoC.ipadx = 50; - JFrame authorsFr = createAddRemoveCitation(authAddBtn, authRmvBtn, authTag); - pn.add(new JLabel("Authors of the bundled model"), labelC); - authorsFr.getContentPane().setSize(8, 20); - pn.add((JComponent) authorsFr.getContentPane(), infoC); - - // Citation field - labelC.gridy = 2; - labelC.ipadx = 0; - labelC.ipady = 0; - infoC.gridy = 2; - infoC.insets = new Insets(0, 0, 0, 0); - infoC.ipady = 50; - infoC.ipadx = 50; - JFrame citationsFr = createAddRemoveCitation(citationAddBtn, citationRmvBtn, citeTag); - pn.add(new JLabel("Citations"), labelC); - citationsFr.getContentPane().setSize(8, 20); - pn.add((JComponent) citationsFr.getContentPane(), infoC); - - // MOdel description field - labelC.gridy = 4; - labelC.gridheight = 3; - labelC.ipadx = 50; - labelC.ipady = 50; - - infoC.gridy = 4; - infoC.gridheight = 3; - - infoC.ipady = 80; - infoC.ipady = 80; - infoC.ipadx = 0; - infoC.insets = new Insets(0, 0, 0, 0); - - pn.add(new JLabel("Description of the model"), labelC); - txtDescription.setLineWrap(true); - txtDescription.setWrapStyleWord(true); - JScrollPane txtScroller = new JScrollPane(txtDescription); - txtScroller.setPreferredSize(new Dimension(txtDescription.getPreferredSize().width, txtDescription.getPreferredSize().height + 50)); - - pn.add(txtScroller, infoC); - - // Docs field - labelC.gridy = 7; - labelC.gridheight = 1; - labelC.ipadx = 0; - labelC.ipady = 0; - infoC.gridy = 7; - infoC.gridheight = 1; - infoC.insets = new Insets(10, 0, 10, 10); - - infoC.ipady = 0; - infoC.ipadx = 0; - pn.add(new JLabel("Link to documentation"), labelC); - pn.add(txtDocumentation, infoC); - - // GIT repo link field - labelC.gridy = 8; - labelC.gridheight = 1; - labelC.ipadx = 0; - labelC.ipady = 0; - infoC.gridy = 8; - infoC.gridheight = 1; - infoC.insets = new Insets(10, 0, 10, 10); - - infoC.ipady = 0; - infoC.ipadx = 0; - pn.add(new JLabel("Link to Github repo"), labelC); - pn.add(txtGitRepo, infoC); - - // Next field - labelC.gridy = 9; - infoC.gridy = 9; - pn.add(new JLabel("Type of license"), labelC); - pn.add(txtLicense, infoC); - - // TAgs field - JFrame tagsFr = createAddRemoveFrame(txtTag, tagAddBtn, "tag", tagRmvBtn); - - labelC.gridy = 10; - labelC.ipadx = 60; - labelC.ipady = 60; - infoC.gridy = 10; - pn.add(new JLabel("Tags to describe the model in the Bioimage Model Zoo"), labelC); - pn.add((JComponent) tagsFr.getContentPane(), infoC); - - JPanel p = new JPanel(new BorderLayout()); - - JScrollPane scroll = new JScrollPane(); - pn.setPreferredSize(new Dimension(pn.getWidth() + 400, pn.getHeight() + 700)); - scroll.setPreferredSize(new Dimension(pn.getWidth() + 300, pn.getHeight() + 400)); - scroll.setViewportView(pn); - - p.add(pane, BorderLayout.NORTH); - p.add(scroll, BorderLayout.CENTER); - panel.add(p); - - - // Add the tad 'deepImageJ' to the tags field. This tag - // is not removable - tagModel = new DefaultListModel(); - tagModel.addElement("deepimagej"); - tagList.setModel(tagModel); - introducedTag.add("deepimagej"); - - authAddBtn.addActionListener(this); - authRmvBtn.addActionListener(this); - - citationAddBtn.addActionListener(this); - citationRmvBtn.addActionListener(this); - - tagAddBtn.addActionListener(this); - tagRmvBtn.addActionListener(this); - } - - @Override - public void init() { - File file = new File(parent.getDeepPlugin().params.path2Model); - if (!model.equals(parent.getDeepPlugin().params.path2Model)) { - txtName.setText(file.getName()); - model = parent.getDeepPlugin().params.path2Model; - - introducedAuth = new ArrayList>(); - authModel = new DefaultListModel>(); - authList.setModel(authModel); - - introducedCitation = new ArrayList>(); - citationModel = new DefaultListModel>(); - citationList.setModel(citationModel); - // Add the tag 'deepImageJ' to the tags field. This tag - // is not removable - introducedTag = new ArrayList(); - tagModel = new DefaultListModel(); - tagModel.addElement("deepimagej"); - tagList.setModel(tagModel); - introducedTag.add("deepimagej"); - - // Reset all the fields - txtAuth.setText(""); - txtTag.setText(""); - txtDocumentation.setText(""); - txtGitRepo.setText(""); - txtLicense.setText(""); - txtDescription.setText(""); - } - } - - @Override - public boolean finish() { - if (txtName.getText().trim().equals("")) { - IJ.error("The name is a mandatory field"); - return false; - } - Parameters params = parent.getDeepPlugin().params; - params.name = txtName.getText().trim(); - - // TODO check if we need to cover here - params.documentation = txtDocumentation.getText().trim(); - params.git_repo = txtGitRepo.getText().trim(); - params.license = txtLicense.getText().trim(); - params.description = txtDescription.getText().trim(); - - params.name = params.name.equals("") ? null : coverForbiddenSymbols(params.name); - params.author = null; - if (introducedAuth.size() > 0) - params.author = introducedAuth; - params.cite = introducedCitation; - - params.documentation = params.documentation.equals("") ? null : params.documentation; - params.git_repo = params.git_repo.equals("") ? null : params.git_repo; - params.license = params.license.equals("") ? null : coverForbiddenSymbols(params.license); - params.description = params.description.equals("") ? null : coverForbiddenSymbols(params.description); - params.infoTags = introducedTag; - - - return true; - } - - // TODO find more forbidden characters - public static String coverForbiddenSymbols(String txt) { - String[] forbidenCharacters = {":", "{", "}", "[", "]", ">", "=", "!", - ",", "&", "*", "#", "?", "|", "-", "<", - "¡", "¿", "%", "@", "Ñ", "ñ"}; - for (String forbidenChar : forbidenCharacters) { - if (txt.contains(forbidenChar)) { - txt = "\'" + txt + "\'"; - break; - } - } - return txt; - } - - public void addAuthor() { - GenericDialog dlg = new GenericDialog("Add author information"); - dlg.addStringField("Name", "", 70); - dlg.addStringField("Affiliation", "", 70); - dlg.addStringField("Orcid", "", 70); - dlg.showDialog(); - if (dlg.wasCanceled()) { - return; - } - Vector strField = dlg.getStringFields(); - TextField nameField = (TextField) strField.get(0); - TextField affField = (TextField) strField.get(1); - TextField orcidField = (TextField) strField.get(2); - HashMap specs = new HashMap(); - specs.put("name", coverForbiddenSymbols(nameField.getText().trim())); - if (specs.get("name").contentEquals("")) - specs.put("name", "n/a"); - specs.put("affiliation", coverForbiddenSymbols(affField.getText().trim())); - if (specs.get("affiliation").contentEquals("")) - specs.put("affiliation", null); - specs.put("orcid", coverForbiddenSymbols(orcidField.getText().trim())); - if (specs.get("orcid").contentEquals("")) - specs.put("orcid", null); - - introducedAuth.add(specs); - - authModel = new DefaultListModel>(); - - // Add the elements to the list - - for (HashMap name : introducedAuth){ - authModel.addElement(name); - } - authList.setModel(authModel); - authList.setCellRenderer(new MyListCellRenderer(authTag)); - } - - public void removeAuthor() { - // Get the author selected - int citation = authList.getSelectedIndex(); - if (citation == -1) { - IJ.error("No citation selected"); - return; - } - introducedAuth.remove(citation); - - authModel = new DefaultListModel>(); - - // Add the elements to the list - - for (HashMap name : introducedAuth){ - authModel.addElement(name); - } - authList.setModel(authModel); - authList.setCellRenderer(new MyListCellRenderer(authTag)); - } - - public void addCite() { - GenericDialog dlg = new GenericDialog("Add reference and its doi"); - dlg.addStringField("Reference", "", 70); - dlg.addStringField("Doi", "http://", 70); - dlg.showDialog(); - if (dlg.wasCanceled()) { - return; - } - Vector strField = dlg.getStringFields(); - TextField refField = (TextField) strField.get(0); - TextField doiField = (TextField) strField.get(1); - HashMap refAndDoi = new HashMap(); - String txt = coverForbiddenSymbols(refField.getText().trim()); - refAndDoi.put("text", txt); - refAndDoi.put("doi", doiField.getText().trim()); - /* Try creating a valid URL */ - boolean url = false; - try { - new URL(refAndDoi.get("doi")).toURI(); - url = true; - } - - // If there was an Exception - // while creating URL object - catch (Exception e) { - url = false; - } - if (!url && !refAndDoi.get("doi").equals("")) { - IJ.error("You need to introduce a valid URL in the doi field or leave it empty."); - addCite(); - return; - } - introducedCitation.add(refAndDoi); - - citationModel = new DefaultListModel>(); - - // Add the elements to the list - - for (HashMap name : introducedCitation){ - citationModel.addElement(name); - } - citationList.setModel(citationModel); - citationList.setCellRenderer(new MyListCellRenderer(citeTag)); - } - - public void removeCite() { - // Get the author selected - int citation = citationList.getSelectedIndex(); - if (citation == -1) { - IJ.error("No citation selected"); - return; - } - introducedCitation.remove(citation); - - citationModel = new DefaultListModel>(); - - // Add the elements to the list - - for (HashMap name : introducedCitation){ - citationModel.addElement(name); - } - citationList.setModel(citationModel); - citationList.setCellRenderer(new MyListCellRenderer(citeTag)); - } - - public void addTag() { - // Get the author introduced - String tag = coverForbiddenSymbols(txtTag.getText().trim()); - if (tag.equals("")) { - IJ.error("Introduce a name"); - return; - } - introducedTag.add(tag); - - tagModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : introducedTag){ - tagModel.addElement(name); - } - tagList.setModel(tagModel); - txtTag.setText(""); - } - public void removeTag() { - // Get the author selected - int tag = tagList.getSelectedIndex(); - if (tag == -1) { - IJ.error("No tag selected"); - return; - } else if (tag == 0) { - IJ.error("Cannot remove 'deepimagej' tag"); - return; - } - introducedTag.remove(tag); - - tagModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : introducedTag){ - tagModel.addElement(name); - } - tagList.setModel(tagModel); - } - - /* - * Method that creates the Gui component that allows adding and removing tags - */ - public JFrame createAddRemoveFrame(JTextField txt, JButton add, String option, JButton rmv) { - // Create the panel to add authors - JFrame authorsFr = new JFrame(); - authorsFr.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); - Container authorsPn = authorsFr.getContentPane(); - authorsPn.setLayout(new GridBagLayout()); - - // creates a constraints object - GridBagConstraints c = new GridBagConstraints(); - c.fill = GridBagConstraints.BOTH; - c.ipady = 5; - c.ipadx = 20; - c.weightx = 1; - c.gridx = 0; - c.gridy = 0; - c.gridwidth = 7; - authorsPn.add(txt, c); - - c.ipady = 0; - c.ipadx = 0; - c.weightx = 0.2; - c.gridx = 7; - c.gridy = 0; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.NONE; - authorsPn.add(add, c); - - c.ipady = 40; - c.ipadx = 20; - c.weightx = 1; - c.weighty = 1; - c.gridwidth = 7; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.BOTH; - c.gridx = 0; - c.gridy = 1; - if (option.contains("auth")) { - //authModel = new DefaultListModel(); - //authModel.addElement(""); - //authList = new JList(authModel); - authList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - authList.setLayoutOrientation(JList.VERTICAL); - authList.setVisibleRowCount(2); - JScrollPane listScroller = new JScrollPane(authList); - listScroller.setPreferredSize(new Dimension(Constants.width, panel.getPreferredSize().height)); - authorsPn.add(listScroller, c); - } else if(option.contains("tag")) { - tagModel = new DefaultListModel(); - tagModel.addElement(""); - tagList = new JList(tagModel); - tagList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - tagList.setLayoutOrientation(JList.VERTICAL); - tagList.setVisibleRowCount(2); - JScrollPane listScroller = new JScrollPane(tagList); - listScroller.setPreferredSize(new Dimension(Constants.width, panel.getPreferredSize().height)); - authorsPn.add(listScroller, c); - } - - c.ipady = 0; - c.ipadx = 0; - c.gridx = 7; - c.gridy = 1; - c.gridheight =1; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.NONE; - c.weightx = 0.2; - Dimension btnDims = authAddBtn.getPreferredSize(); - rmv.setPreferredSize(btnDims); - authorsPn.add(rmv, c); - authorsFr.pack(); - return authorsFr; - } - - /* - * Method that creates the Gui component that allows adding and removing citations - */ - public JFrame createAddRemoveCitation(JButton add, JButton rmv, String option) { - // Create the panel to add authors - JFrame authorsFr = new JFrame(); - authorsFr.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); - Container authorsPn = authorsFr.getContentPane(); - authorsPn.setLayout(new GridBagLayout()); - - // creates a constraints object - GridBagConstraints c = new GridBagConstraints(); - c.fill = GridBagConstraints.BOTH; - c.ipady = 60; - c.ipadx = 20; - c.weightx = 1; - c.weighty = 1; - c.gridwidth = 7; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.BOTH; - c.gridx = 0; - c.gridy = 0; - c.gridheight =2; - - if (option.contains(authTag)) { - authModel = new DefaultListModel>(); - authList = new JList>(authModel); - authList.setCellRenderer(new MyListCellRenderer(authTag)); - authList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - authList.setLayoutOrientation(JList.VERTICAL); - authList.setVisibleRowCount(4); - JScrollPane listScroller = new JScrollPane(authList); - listScroller.setPreferredSize(new Dimension(Constants.width, panel.getPreferredSize().height)); - authorsPn.add(listScroller, c); - } else if(option.contains(citeTag)) { - citationModel = new DefaultListModel>(); - citationList = new JList>(citationModel); - citationList.setCellRenderer(new MyListCellRenderer(citeTag)); - citationList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - citationList.setLayoutOrientation(JList.VERTICAL); - citationList.setVisibleRowCount(4); - JScrollPane listScroller = new JScrollPane(citationList); - listScroller.setPreferredSize(new Dimension(Constants.width, panel.getPreferredSize().height)); - authorsPn.add(listScroller, c); - } - - c.ipady = 0; - c.ipadx = 0; - c.weightx = 0.2; - c.gridx = 7; - c.gridy = 0; - c.gridheight =1; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.NONE; - authorsPn.add(add, c); - - - c.ipady = 0; - c.ipadx = 0; - c.gridx = 7; - c.gridy = 1; - c.gridheight =1; - c.anchor = GridBagConstraints.CENTER; - c.fill = GridBagConstraints.NONE; - c.weightx = 0.2; - Dimension btnDims = authAddBtn.getPreferredSize(); - rmv.setPreferredSize(btnDims); - authorsPn.add(rmv, c); - authorsFr.pack(); - return authorsFr; - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == tagAddBtn) { - addTag(); - } - if (e.getSource() == tagRmvBtn) { - removeTag(); - } - if (e.getSource() == authAddBtn) { - addAuthor(); - } - if (e.getSource() == authRmvBtn) { - removeAuthor(); - } - if (e.getSource() == citationAddBtn) { - addCite(); - } - if (e.getSource() == citationRmvBtn) { - removeCite(); - } - } - - private class MyListCellRenderer extends DefaultListCellRenderer { - - private String tag; - - public MyListCellRenderer(String tag) { - this.tag = tag; - } - - @Override - public Component getListCellRendererComponent( - JList list, Object value, int index, - boolean isSelected, boolean cellHasFocus) { - super.getListCellRendererComponent(list, value, index, isSelected, cellHasFocus); - HashMap label = (HashMap) value; - if (tag.toLowerCase().contentEquals(citeTag)) { - String text = label.get("text"); - String doi = label.get("doi"); - if (label.keySet().size() > 0) { - String labelText = "- " + text + "
" + " " + doi; - setText(labelText); - } - } else if (tag.toLowerCase().contentEquals(authTag)) { - String name = label.get("name"); - String aff = label.get("affiliation"); - String orcid = label.get("orcid"); - if (label.keySet().size() > 0) { - String labelText = "- " + name + "
" + " " + aff + "
" + " " + orcid; - setText(labelText); - } - } - return this; - } - - } -} diff --git a/src/main/java/deepimagej/stamp/InputDimensionStamp.java b/src/main/java/deepimagej/stamp/InputDimensionStamp.java deleted file mode 100755 index 1b0cbae1..00000000 --- a/src/main/java/deepimagej/stamp/InputDimensionStamp.java +++ /dev/null @@ -1,501 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.Dimension; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.GridLayout; -import java.awt.Insets; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.util.ArrayList; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.BoxLayout; -import javax.swing.JButton; -import javax.swing.JComboBox; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.JTextField; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.GridPanel; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Index; -import ij.IJ; - -public class InputDimensionStamp extends AbstractStamp implements ActionListener { - - - private List allTxtMinSize = new ArrayList(); - private List allTxtStep = new ArrayList(); - - - private static String allowPatches = "Allow tilling"; - private static String notAllowPatches = "Do not allow tilling"; - - private JComboBox cmbPatches = new JComboBox(new String[] {allowPatches, notAllowPatches}); - // TODO rmove private JLabel lblPatches = new JLabel("Patch size"); - private JLabel lblMinSize = new JLabel("Minimum Size"); - private JLabel lblStep = new JLabel("Step Size"); - - private JButton bnNextOutput = new JButton("Next Output"); - private JButton bnPrevOutput = new JButton("Previous Output"); - private GridPanel pnInput = new GridPanel(); - - private static JComboBox cmbRangeLow = new JComboBox(new String [] {"-inf", "-1", "0", "1", "inf"}); - private static JComboBox cmbRangeHigh = new JComboBox(new String [] {"-inf", "-1", "0", "1", "inf"}); - private static double[] rangeOptions = {Double.NEGATIVE_INFINITY, (double) -1, (double) 0, (double) 1, Double.POSITIVE_INFINITY}; - - private List imageTensors; - private static int inputCounter = 0; - // Parameters to know if something changed and we have to rebuild the GUI - private String model = ""; - private List savedInputs = null; - private boolean tiling; - - - // Whether we need to add or not action listener to the text fields - private boolean listenTxtField = false; - - public InputDimensionStamp(BuildDialog parent) { - super(parent); - buildPanel(); - cmbRangeHigh.setSelectedIndex(4); - } - - @Override - public void buildPanel() { - - HTMLPane info = new HTMLPane(Constants.width, 265); - info.append("h", "Input size constraints"); - info.append("p", "Input tile size (Q) : Input size of the model. If Allow tiling or" - + " Do not allow tiling (with variable input size) is selected, Q will automatically " - + "change for each image during the inference and suggested to the user in DeepImageJ Run."); - info.append("p", "Minimum size (m) : Size of the smallest input that the model can process."); - info.append("p", "Step (s) : If the network supports different input sizes, the size of each" - + " dimension of the input image (Q), has to be the result of 'Minimum size (m) + N * Step (s)'," - + " where N can be any positive integer."); - - info.append("h", "Tiling strategies"); - info.append("p", "Allow tiling: Q is editable by the user as long as it " - + "fulfills the step (s) and minimum (m) constraints. Large images are processed using a tiling strategy."); - info.append("p", "Do not allow tiling: The input size will be processed as a whole (no tiling). Depending " - + "on the step (s) and minimum (m) constraints, the model might not be applicable to " - + "some images (too big or small)."); - - JPanel buttons = new JPanel(new GridLayout(1, 2)); - buttons.setBorder(BorderFactory.createEtchedBorder()); - buttons.add(bnPrevOutput); - buttons.add(bnNextOutput); - - // Create auxiliary DijTensor to initialise the interface - DijTensor auxTensor = new DijTensor("aux"); - auxTensor.tensorType = "image"; - auxTensor.tensor_shape = new int[5]; - auxTensor.form = "BZYXC"; - boolean start = true; - buildPanelForImage(auxTensor, start); - - JPanel pn = new JPanel(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.setLayout(new GridBagLayout()); - GridBagConstraints c = new GridBagConstraints(); - c.gridheight = 8; - c.gridy = 0; - c.weightx = 1; - c.ipady = 0; - c.fill = GridBagConstraints.BOTH; - c.insets = new Insets(0, 0, 0, 0); - pn.add(info.getPane(), c); - c.gridheight = 8; - c.weightx = 1; - c.gridy = 8; - c.ipady = 20; - c.insets = new Insets(0, 0, 0, 0); - pn.add(pnInput, c); - c.gridheight = 1; - c.weightx = 1; - c.gridy = 16; - c.ipady = 0; - c.insets = new Insets(0, 0, 0, 0); - pn.add(buttons, c); - /* - pn.add(info.getPane()); - pn.add(pnInput); - pn.add(buttons, BorderLayout.SOUTH); - */ - - panel.add(pn); - - bnNextOutput.addActionListener(this); - bnPrevOutput.addActionListener(this); - } - - @Override - public void init() { - //String modelOfInterest = parent.getDeepPlugin().params.path2Model; - Parameters params = parent.getDeepPlugin().params; - imageTensors = DijTensor.getImageTensors(params.inputList); - // Set the screen at the first input if the model changes - String modelOfInterest = params.path2Model; - // Repaint interface if model has changed - if (!modelOfInterest.equals(model)) { - showCorrespondingInputInterface(params); - // To avoid referencing - tiling = true == params.allowPatching; - savedInputs = DijTensor.copyTensorList(params.inputList); - model = modelOfInterest; - inputCounter = 0; - return; - } - // Repaint interface if the number of input tensors has changed - if (params.inputList.size() != savedInputs.size()) { - showCorrespondingInputInterface(params); - // To avoid referencing - tiling = true == params.allowPatching; - savedInputs = DijTensor.copyTensorList(params.inputList); - inputCounter = 0; - return; - } - // If the number of inputs is the same, check if their shape or type have changed - for (int i = 0; i < params.inputList.size(); i ++) { - if (!params.inputList.get(i).tensorType.equals(savedInputs.get(i).tensorType) - || !params.inputList.get(i).form.equals(savedInputs.get(i).form)) { - showCorrespondingInputInterface(params); - // To avoid referencing - tiling = true == params.allowPatching; - savedInputs = DijTensor.copyTensorList(params.inputList); - inputCounter = 0; - return; - } - } - // If the model has changed from allow patching to not allow patching - // or from pyramidal to not pyramidal, repaint again. - if (tiling != params.allowPatching) { - showCorrespondingInputInterface(params); - // To avoid referencing - tiling = true == params.allowPatching; - inputCounter = 0; - return; - } - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - saveInputData(params); - for (DijTensor inp : params.inputList) { - if (inp.tensorType.contains("image") && !inp.finished) - return false; - } - // Save to know when to repaint the interface - savedInputs = params.inputList; - tiling = params.allowPatching; - return true; - } - - public void showCorrespondingInputInterface(Parameters params) { - - // Check how many outputs there are to enable or not - // the "next" and "back" buttons - int nImageTensors = imageTensors.size(); - if (inputCounter == 0) { - bnPrevOutput.setEnabled(false); - } else { - //bnPrevOutput.setEnabled(true); - bnPrevOutput.setEnabled(false); - } - if (inputCounter < (nImageTensors - 1)) { - //bnNextOutput.setEnabled(true); - bnPrevOutput.setEnabled(false); - } else { - bnNextOutput.setEnabled(false); - } - - // Reinitialise all the params - allTxtMinSize = new ArrayList(); - allTxtStep = new ArrayList(); - pnInput.removeAll(); - DijTensor tensor = imageTensors.get(inputCounter); - if (tensor.tensorType.contains("image")) { - // Build the panel - // Set listenTxtField to false because we need to add new - // listeners to the new text fields - listenTxtField = false; - buildPanelForImage(tensor); - updateImageInterface(tensor); - } else { - inputCounter ++; - } - pnInput.revalidate(); - pnInput.repaint(); - - } - - private void updateImageInterface(DijTensor tensor) { - - String[] dim = DijTensor.getWorkingDims(tensor.form); - int[] dimValues = DijTensor.getWorkingDimValues(tensor.form, tensor.tensor_shape); - cmbPatches.setEnabled(true); - - - // Allow patch decomposition - for (int i = 0; i < dim.length; i ++) { - if (dimValues[i] != -1 && !listenTxtField) { - allTxtMinSize.get(i).setText("" + dimValues[i]); - allTxtMinSize.get(i).setEditable(false); - allTxtStep.get(i).setText("" + 0); - allTxtStep.get(i).setEditable(false); - } else if (dimValues[i] == -1){ - allTxtMinSize.get(i).setEditable(true); - String stepGuess = dim[i].equals("C") ? "0" : "1"; - allTxtStep.get(i).setText(stepGuess); - allTxtStep.get(i).setEditable(true); - } - } - } - - /* - * Method to retrieve from the UI the information necessary to build - * whatever object is needed for the input tensor - */ - public boolean saveInputData(Parameters params) { - // If the methods saving the info were successful, wasSaved=true - boolean wasSaved = false; - ArrayList imageTensorInds = new ArrayList(); - for (int i = 0; i < params.inputList.size(); i ++) { - if (params.inputList.get(i).tensorType.contains("image")) - imageTensorInds.add(i); - } - int trueInd = imageTensorInds.get(inputCounter); - DijTensor tensor = params.inputList.get(trueInd); - if (imageTensorInds.size() > 0 && tensor.tensorType.contains("image")) { - wasSaved = saveInputDataForImage(params, tensor); - } - - int lowInd = cmbRangeLow.getSelectedIndex(); - int highInd = cmbRangeHigh.getSelectedIndex(); - if (lowInd >= highInd) { - IJ.error("The Data Range has to go from a value to a higher one."); - return false; - } - - params.inputList.get(trueInd).dataRange[0] = rangeOptions[lowInd]; - params.inputList.get(trueInd).dataRange[1] = rangeOptions[highInd]; - params.inputList.get(trueInd).finished = wasSaved; - - return wasSaved; - } - - /* - * Method to retrieve from the UI the information necessary to build an - * image from the tensor inputed to the model - */ - public boolean saveInputDataForImage(Parameters params, DijTensor tensor) { - params.fixedInput = true; - int[] min_size = new int[tensor.form.length()]; - int[] step = new int[tensor.form.length()]; - - int batchInd = Index.indexOf(tensor.form.split(""), "B"); - if (batchInd != -1) { - min_size[batchInd] = 1; step[batchInd] = 0; - } - - // Selected tiling option - String selection = (String) cmbPatches.getSelectedItem(); - - boolean auxDetectError = true; - try { - int auxCount = 0; - for (int c = 0; c < tensor.tensor_shape.length; c ++) { - if (c != batchInd) { - auxDetectError = false; - min_size[c] = Integer.parseInt(allTxtMinSize.get(auxCount).getText()); - auxDetectError = true; - step[c] = Integer.parseInt(allTxtStep.get(auxCount).getText()); - if (min_size[c] <= 0) { - IJ.error("The step should be larger than 0"); - return false; - } - if (step[c] < 0) { - IJ.error("The step size should be larger or equal to 0"); - return false; - } - auxCount ++; - } - } - } - catch (Exception ex) { - if (auxDetectError) { - IJ.error("The patch size is not a correct integer"); - } else if (!auxDetectError) { - IJ.error("The step is not a correct integer"); - } - return false; - } - - - if (selection.contains(allowPatches)) { - params.allowPatching = true; - } else if (selection.contains(notAllowPatches)) { - // The patch is always the same. No step because - // no more sizes are allowed - params.allowPatching = false; - } - for (int ss : step) { - if (ss != 0) { - params.fixedInput = false; - break; - } - } - - tensor.minimum_size = min_size; - tensor.step = step; - - return true; - } - - private void buildPanelForImage(DijTensor tensor) { - buildPanelForImage(tensor, false); - } - - private void buildPanelForImage(DijTensor tensor, boolean start) { - // Build panel for when the input tensor is an image - - allTxtMinSize = new ArrayList(); - allTxtStep = new ArrayList(); - String[] dims = DijTensor.getWorkingDims(tensor.form); - - JPanel pnMinSize = new JPanel(new GridLayout(2, dims.length)); - JPanel pnStep = new JPanel(new GridLayout(2, dims.length)); - - for (String dim: dims) { - JLabel dimLetter1 = new JLabel(dim); - dimLetter1.setPreferredSize( new Dimension( 10, 20 )); - JLabel dimLetter2 = new JLabel(dim); - dimLetter2.setPreferredSize( new Dimension( 10, 20 )); - - pnMinSize.add(dimLetter1); - pnStep.add(dimLetter2); - - } - - for (int i = 0; i < dims.length; i ++) { - JTextField txtMultiple = new JTextField("1", 5); - txtMultiple.setPreferredSize( new Dimension( 10, 20 )); - JTextField txtStep = new JTextField("0", 5); - txtStep.setPreferredSize( new Dimension( 10, 20 )); - - pnMinSize.add(txtMultiple); - allTxtMinSize.add(txtMultiple); - - pnStep.add(txtStep); - allTxtStep.add(txtStep); - } - - // If we are building the screen at the start of the plugin, - // and there are no params defined, just start by default - if (!start) { - Parameters params = parent.getDeepPlugin().params; - if ((!params.allowPatching || params.pyramidalNetwork)) { - // If we do not allow patching, do no show the corresponding options in the combobox - cmbPatches = new JComboBox(new String[] {notAllowPatches}); - } else if (params.allowPatching && !params.pyramidalNetwork) { - // If we allow patching but the size is fixed by the model, do not show the two options that - // allow freedom in the input image - cmbPatches = new JComboBox(new String[] {allowPatches, notAllowPatches}); - } - } else { - cmbPatches = new JComboBox(new String[] { allowPatches, notAllowPatches}); - } - - pnInput.removeAll(); - pnInput.setBorder(BorderFactory.createEtchedBorder()); - pnInput.place(0, 0, 2, 1, new JLabel("Name: " + tensor.name + " Input type: " + tensor.tensorType)); - pnInput.place(1, 0, 2, 1, cmbPatches); - pnInput.place(2, 0, lblMinSize); - pnInput.place(2, 1, pnMinSize); - pnInput.place(3, 0, lblStep); - pnInput.place(3, 1, pnStep); - GridPanel pnRange1 = new GridPanel(true); - pnRange1.place(0, 0, new JLabel("Data Range lower bound")); - pnRange1.place(0, 1, cmbRangeLow); - pnInput.place(4, 0, pnRange1); - GridPanel pnRange2 = new GridPanel(true); - pnRange2.place(0, 0, new JLabel("Data Range higher bound")); - pnRange2.place(0, 1, cmbRangeHigh); - pnInput.place(4, 1, pnRange2); - - cmbRangeLow.setEditable(false); - cmbRangeHigh.setEditable(false); - - cmbPatches.addActionListener(this); - } - - @Override - public void actionPerformed(ActionEvent e) { - Parameters params = parent.getDeepPlugin().params; - if (e.getSource() == bnNextOutput) { - if (saveInputData(params)) { - inputCounter ++; - showCorrespondingInputInterface(params); - } - } - if (e.getSource() == bnPrevOutput) { - inputCounter --; - showCorrespondingInputInterface(params); - } - if (e.getSource() == cmbPatches) { - updateImageInterface(params.inputList.get(inputCounter)); - } - } - -} diff --git a/src/main/java/deepimagej/stamp/JavaPostprocessingStamp.java b/src/main/java/deepimagej/stamp/JavaPostprocessingStamp.java deleted file mode 100755 index 609942b4..00000000 --- a/src/main/java/deepimagej/stamp/JavaPostprocessingStamp.java +++ /dev/null @@ -1,427 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Dimension; -import java.awt.FlowLayout; -import java.awt.Font; -import java.awt.GridBagConstraints; -import java.awt.GridLayout; -import java.awt.Insets; -import java.awt.Panel; -import java.awt.datatransfer.DataFlavor; -import java.awt.datatransfer.Transferable; -import java.awt.datatransfer.UnsupportedFlavorException; -import java.awt.dnd.DnDConstants; -import java.awt.dnd.DropTarget; -import java.awt.dnd.DropTargetDropEvent; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.DefaultListModel; -import javax.swing.JButton; -import javax.swing.JFileChooser; -import javax.swing.JFrame; -import javax.swing.JList; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextField; -import javax.swing.ListSelectionModel; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.BorderPanel; -import deepimagej.components.HTMLPane; -import ij.IJ; -import ij.gui.GenericDialog; - -public class JavaPostprocessingStamp extends AbstractStamp implements ActionListener { - - private JTextField txt1 = new JTextField("Drop zone for the first postprocessing file"); - private JTextField txt2 = new JTextField("Drop zone for the second postprocessing file"); - private JButton bnBrowse1 = new JButton("Browse"); - private JButton bnBrowse2 = new JButton("Browse"); - - private static JTextField depPath = new JTextField("Introduce/drop post-processing dependecy jar file"); - public static JList dependenciesList = new JList(); - private static DefaultListModel dependenciesModel; - public static JButton addBtn = new JButton("Add"); - public static JButton rmvBtn = new JButton("Remove"); - - // Variable to keep track of the model being used - private String model = ""; - - public JavaPostprocessingStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - - HTMLPane pane = new HTMLPane(Constants.width, 90); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "External Postprocessing"); - pane.append("p", "(Optional) Add the required preprocessing for the input image.\n" - + "It supports ImageJ macro routines or Java code.\n" + - "Macro routines allow '.txt' or '.ijm' extensions."); - pane.append("p", "The Java code can be included with either '.class' or '.jar' files"); - - txt1.setFont(new Font("Arial", Font.BOLD, 14)); - txt1.setForeground(Color.red); - //txt1.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load1 = new JPanel(new BorderLayout()); - load1.setBorder(BorderFactory.createEtchedBorder()); - load1.add(txt1, BorderLayout.CENTER); - load1.add(bnBrowse1, BorderLayout.EAST); - - txt2.setFont(new Font("Arial", Font.BOLD, 14)); - txt2.setForeground(Color.red); - //txt2.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load2 = new JPanel(new BorderLayout()); - load2.setBorder(BorderFactory.createEtchedBorder()); - load2.add(txt2, BorderLayout.CENTER); - load2.add(bnBrowse2, BorderLayout.EAST); - - JPanel load = new JPanel(new GridLayout(2,0)); - load.add(load1); - load.add(load2); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(pane.getPane(), BorderLayout.NORTH); - - pn.add(load, BorderLayout.CENTER); - - panel.add(pn); - - txt1.setDropTarget(new LocalDropTarget(txt1)); - load1.setDropTarget(new LocalDropTarget(txt1)); - bnBrowse1.addActionListener(this); - - txt2.setDropTarget(new LocalDropTarget(txt2)); - load2.setDropTarget(new LocalDropTarget(txt2)); - bnBrowse2.addActionListener(this); - - // Action listeners for the build path GUI. - // This GUI only appears if a Java processing is included - depPath.setDropTarget(new LocalDropTarget(depPath)); - addBtn.addActionListener(this); - rmvBtn.addActionListener(this); - } - - @Override - public void init() { - Parameters params = parent.getDeepPlugin().params; - if (params.firstPostprocessing == null || !model.contentEquals(params.path2Model)) - txt1.setText("Drop zone for the first postprocessing"); - if (params.secondPostprocessing == null || !model.contentEquals(params.path2Model)) - txt2.setText("Drop zone for the second postprocessing"); - if (!model.contentEquals(params.path2Model)) { - model = params.path2Model; - } - parent.getDeepPlugin().params.postAttachments = new ArrayList(); - dependenciesModel = new DefaultListModel(); - dependenciesList.setModel(dependenciesModel); - - } - - @Override - public boolean finish() { - String filename1 = txt1.getText(); - String filename2 = txt2.getText(); - parent.getDeepPlugin().params.firstPostprocessing = null; - parent.getDeepPlugin().params.secondPostprocessing = null; - if (filename1.contains(File.separator)) { - File file1 = new File(filename1); - if (!file1.exists()) { - IJ.error("This directory " + filename1 + " doesn't exist"); - return false; - } - if ((file1.isFile()) && (!file1.getAbsolutePath().contains(".txt") && !file1.getAbsolutePath().contains(".ijm")) && (!file1.getAbsolutePath().contains(".class")) && (!file1.getAbsolutePath().contains(".jar"))) { - IJ.error("The path " + filename1 + " does not corresponf to a valid macro or Java file"); - return false; - } - parent.getDeepPlugin().params.firstPostprocessing = filename1; - } - - if (filename2.contains(File.separator)) { - File file2 = new File(filename2); - if (!file2.exists()) { - IJ.error("This directory " + filename2 + " doesn't exist"); - return false; - } - if ((file2.isFile()) && (!file2.getAbsolutePath().contains(".txt") && !file2.getAbsolutePath().contains(".ijm")) && (!file2.getAbsolutePath().contains(".class")) && (!file2.getAbsolutePath().contains(".jar"))) { - IJ.error("The path " + filename2 + " does not corresponf to a valid macro or Java file"); - return false; - } - parent.getDeepPlugin().params.secondPostprocessing = filename2; - } - boolean result = true; - if (filename1.endsWith(".jar") || filename1.endsWith(".class") || filename2.endsWith(".jar") || filename2.endsWith(".class")) { - result = addJavaDependencies(); - } - return result; - } - - public class LocalDropTarget extends DropTarget { - private JTextField id; - public LocalDropTarget(JTextField id) { - this.id = id; - } - - @Override - public void drop(DropTargetDropEvent e) { - e.acceptDrop(DnDConstants.ACTION_COPY); - e.getTransferable().getTransferDataFlavors(); - Transferable transferable = e.getTransferable(); - DataFlavor[] flavors = transferable.getTransferDataFlavors(); - for (DataFlavor flavor : flavors) { - if (flavor.isFlavorJavaFileListType()) { - try { - List files = (List) transferable.getTransferData(flavor); - for (File file : files) { - id.setText(file.getAbsolutePath()); - id.setCaretPosition(1); - } - } - catch (UnsupportedFlavorException ex) { - ex.printStackTrace(); - } - catch (IOException ex) { - ex.printStackTrace(); - } - } - } - e.dropComplete(true); - super.drop(e); - } - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnBrowse1) { - browse(true); - } else if (e.getSource() == bnBrowse2) { - browse(false); - } else if (e.getSource() == addBtn) { - addDependency(); - } else if (e.getSource() == rmvBtn) { - removeDependency(); - } - } - - private void browse(boolean firstProcessing) { - JFileChooser chooser = new JFileChooser(txt1.getText()); - //chooser.setFileSelectionMode(JFileChooser.FILES_ONLY); - chooser.setDialogTitle("Select preprocessing jar"); - int ret = chooser.showOpenDialog(new JFrame()); - if (ret == JFileChooser.APPROVE_OPTION) { - if (firstProcessing) { - txt1.setText(chooser.getSelectedFile().getAbsolutePath()); - txt1.setCaretPosition(1); - } else { - txt2.setText(chooser.getSelectedFile().getAbsolutePath()); - txt2.setCaretPosition(1); - } - } - } - - /** - * Opens GUI to link Java dependencies - */ - public boolean addJavaDependencies() { - - GenericDialog dlg = new GenericDialog("Java Build Path and external files"); - dlg.addMessage("Add path to the Java .jar dependencies needed to run the pre-processing."); - dlg.addMessage("You can also add files required for the execution of the Java code, such as config files."); - dlg.addMessage("The formats allowed for Java dependencies are '.class' and '.jar'."); - dlg.addMessage("If there are no dependencies or files needed simply press 'OK'."); - - Panel loadPath = new Panel(); - loadPath.setLayout(new FlowLayout()); - loadPath.add(depPath, BorderLayout.CENTER); - depPath.setText("Drop file needed for post-processing"); - depPath.setFont(new Font("Arial", Font.BOLD, 11)); - depPath.setForeground(Color.GRAY); - depPath.setPreferredSize(new Dimension(300, 50)); - - // Panel for buttons - JPanel buttons = new JPanel(); - buttons.setLayout(new GridLayout()); - buttons.add(addBtn); - buttons.add(rmvBtn); - - loadPath.add(buttons, BorderLayout.EAST); - loadPath.setVisible(true); - dlg.addPanel(loadPath, GridBagConstraints.CENTER, new Insets(5, 0, 5, 0)); - Dimension panelSize = loadPath.getPreferredSize(); - - BorderPanel panel = new BorderPanel(); - dependenciesModel = new DefaultListModel(); - dependenciesModel.addElement(""); - dependenciesList = new JList(dependenciesModel); - dependenciesList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - dependenciesList.setLayoutOrientation(JList.VERTICAL); - dependenciesList.setVisibleRowCount(2); - JScrollPane listScroller = new JScrollPane(dependenciesList); - panel.add(listScroller); - dlg.addPanel(panel, GridBagConstraints.CENTER, new Insets(5, 0, 0, 0)); - - - loadPath.setPreferredSize(new Dimension((int) Math.round(panelSize.getWidth() * 1), (int) Math.round(panelSize.getHeight() * 1))); - listScroller.setPreferredSize(new Dimension((int) Math.round(panelSize.getWidth() * 1), (int) Math.round(panelSize.getHeight() * 2))); - - dlg.showDialog(); - - if (dlg.wasOKed()) { - parent.getDeepPlugin().params.attachments = new ArrayList(); - for (String pre : parent.getDeepPlugin().params.preAttachments) - parent.getDeepPlugin().params.attachments.add(pre); - for (String post : parent.getDeepPlugin().params.postAttachments) { - if (!parent.getDeepPlugin().params.attachments.contains(post)) - parent.getDeepPlugin().params.attachments.add(post); - } - return true; - } - - return false; - } - - /** - * Adds dependency introduced to the list. Only accept .jar file - */ - public void addDependency() { - // Get the author introduced - String tag = depPath.getText().trim(); - if (tag.equals("")) { - IJ.error("Introduce the path to an external file."); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - return; - } else if(!(new File(tag)).isFile()) { - IJ.error("The path introduced does not correspond to an existing file."); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - return; - } - for (String dep : parent.getDeepPlugin().params.postAttachments) { - if (dep.contentEquals(tag)) { - IJ.error("Do not add the same file twice."); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - return; - } - } - // Check that the name of the files introduced does not coincide - // with the name of a file given during pre-processing, unless it - // is the same file - String postName = tag.substring(tag.lastIndexOf(File.separator) + 1); - for (String dep : parent.getDeepPlugin().params.preAttachments) { - String preName = dep.substring(dep.lastIndexOf(File.separator) + 1); - if (preName.contentEquals(postName) && !dep.contentEquals(tag) && !tag.endsWith(".jar")) { - IJ.error("A file called '" + postName + "' was already added for pre-processing.\n" - + "Cannot add file with the same filename unless it is the exact same file (same path)."); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - return; - } - } - // Check that the name of the files introduced does not coincide - // with the name of a file given during post-processing, - for (String dep : parent.getDeepPlugin().params.postAttachments) { - String pName = dep.substring(dep.lastIndexOf(File.separator) + 1); - if (pName.contentEquals(postName) && !dep.contentEquals(tag) && !tag.endsWith(".jar")) { - IJ.error("A file called '" + postName + "' was already added for post-processing.\n" - + "Cannot add two files with the same name."); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - return; - } - } - - parent.getDeepPlugin().params.postAttachments.add(tag); - - dependenciesModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : parent.getDeepPlugin().params.postAttachments){ - dependenciesModel.addElement(name); - } - dependenciesList.setModel(dependenciesModel); - // Empty the text in the text field - depPath.setText("Drop file needed for post-processing"); - } - - /** - * Remove dependency previoulsy introduced - */ - public void removeDependency() { - // Get the author selected - int tag = dependenciesList.getSelectedIndex(); - if (tag == -1) { - IJ.error("No file selected to remove"); - return; - } - parent.getDeepPlugin().params.postAttachments.remove(tag); - - dependenciesModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : parent.getDeepPlugin().params.postAttachments){ - dependenciesModel.addElement(name); - } - dependenciesList.setModel(dependenciesModel); - } -} diff --git a/src/main/java/deepimagej/stamp/JavaPreprocessingStamp.java b/src/main/java/deepimagej/stamp/JavaPreprocessingStamp.java deleted file mode 100755 index 2cf0e75a..00000000 --- a/src/main/java/deepimagej/stamp/JavaPreprocessingStamp.java +++ /dev/null @@ -1,406 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Dimension; -import java.awt.FlowLayout; -import java.awt.Font; -import java.awt.GridBagConstraints; -import java.awt.GridLayout; -import java.awt.Insets; -import java.awt.Panel; -import java.awt.datatransfer.DataFlavor; -import java.awt.datatransfer.Transferable; -import java.awt.datatransfer.UnsupportedFlavorException; -import java.awt.dnd.DnDConstants; -import java.awt.dnd.DropTarget; -import java.awt.dnd.DropTargetDropEvent; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.DefaultListModel; -import javax.swing.JButton; -import javax.swing.JFileChooser; -import javax.swing.JFrame; -import javax.swing.JList; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextField; -import javax.swing.ListSelectionModel; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.BorderPanel; -import deepimagej.components.HTMLPane; -import ij.IJ; -import ij.gui.GenericDialog; - -public class JavaPreprocessingStamp extends AbstractStamp implements ActionListener { - - private JTextField txt1 = new JTextField("Drop zone for the first preprocessing file"); - private JTextField txt2 = new JTextField("Drop zone for the second preprocessing file"); - private JButton bnBrowse1 = new JButton("Browse"); - private JButton bnBrowse2 = new JButton("Browse"); - - private static JTextField depPath = new JTextField(""); - public static JList dependenciesList = new JList(); - private static DefaultListModel dependenciesModel; - public static JButton addBtn = new JButton("Add"); - public static JButton rmvBtn = new JButton("Remove"); - - // Variable to keep track of the model being used - private String model = ""; - - public JavaPreprocessingStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - - HTMLPane pane = new HTMLPane(Constants.width, 90); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "External Preprocessing"); - pane.append("p", "(Optional) Add the required preprocessing for the input image.\n" - + "It supports ImageJ macro routines or Java code.\n" + - "Macro routines allow '.txt' or '.ijm' extensions."); - pane.append("p", "The Java code can be included with either '.class' or '.jar' files"); - - txt1.setFont(new Font("Arial", Font.BOLD, 14)); - txt1.setForeground(Color.red); - //txt1.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load1 = new JPanel(new BorderLayout()); - load1.setBorder(BorderFactory.createEtchedBorder()); - load1.add(txt1, BorderLayout.CENTER); - load1.add(bnBrowse1, BorderLayout.EAST); - - txt2.setFont(new Font("Arial", Font.BOLD, 14)); - txt2.setForeground(Color.red); - //txt2.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load2 = new JPanel(new BorderLayout()); - load2.setBorder(BorderFactory.createEtchedBorder()); - load2.add(txt2, BorderLayout.CENTER); - load2.add(bnBrowse2, BorderLayout.EAST); - - JPanel load = new JPanel(new GridLayout(2,0)); - load.add(load1); - load.add(load2); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(pane.getPane(), BorderLayout.NORTH); - - pn.add(load, BorderLayout.CENTER); - - panel.add(pn); - - txt1.setDropTarget(new LocalDropTarget(txt1)); - load1.setDropTarget(new LocalDropTarget(txt1)); - bnBrowse1.addActionListener(this); - - txt2.setDropTarget(new LocalDropTarget(txt2)); - load2.setDropTarget(new LocalDropTarget(txt2)); - bnBrowse2.addActionListener(this); - - // Action listeners for the build path GUI. - // This GUI only appears if a Java processing is included - depPath.setDropTarget(new LocalDropTarget(depPath)); - //loadPath.setDropTarget(new LocalDropTarget(depPath)); - addBtn.addActionListener(this); - rmvBtn.addActionListener(this); - - } - - @Override - public void init() { - Parameters params = parent.getDeepPlugin().params; - if (params.firstPreprocessing == null || !model.contentEquals(params.path2Model)) - txt1.setText("Drop zone for the first preprocessing"); - if (params.secondPreprocessing == null || !model.contentEquals(params.path2Model)) - txt2.setText("Drop zone for the second preprocessing"); - if (!model.contentEquals(params.path2Model)) { - model = params.path2Model; - } - params.preAttachments = new ArrayList(); - dependenciesModel = new DefaultListModel(); - dependenciesList.setModel(dependenciesModel); - } - - @Override - public boolean finish() { - String filename1 = txt1.getText(); - String filename2 = txt2.getText(); - parent.getDeepPlugin().params.firstPreprocessing = null; - parent.getDeepPlugin().params.secondPreprocessing = null; - if (filename1.contains(File.separator)) { - File file1 = new File(filename1); - if (!file1.exists()) { - IJ.error("This directory " + filename1 + " doesn't exist"); - return false; - } - if ((file1.isFile()) && (!file1.getAbsolutePath().contains(".txt") && !file1.getAbsolutePath().contains(".ijm")) && (!file1.getAbsolutePath().contains(".class")) && (!file1.getAbsolutePath().contains(".jar"))) { - IJ.error("The path " + filename1 + " does not corresponf to a valid macro or Java file"); - return false; - } - parent.getDeepPlugin().params.firstPreprocessing = filename1; - } - - if (filename2.contains(File.separator)) { - File file2 = new File(filename2); - if (!file2.exists()) { - IJ.error("This directory " + filename2 + " doesn't exist"); - return false; - } - if ((file2.isFile()) && (!file2.getAbsolutePath().contains(".txt") && !file2.getAbsolutePath().contains(".ijm")) && (!file2.getAbsolutePath().contains(".class")) && (!file2.getAbsolutePath().contains(".jar"))) { - IJ.error("The path " + filename2 + " does not corresponf to a valid macro or Java file"); - return false; - } - parent.getDeepPlugin().params.secondPreprocessing = filename2; - } - boolean result = true; - if (filename1.endsWith(".jar") || filename1.endsWith(".class") || filename2.endsWith(".jar") || filename2.endsWith(".class")) { - result = addJavaDependencies(); - } - return result; - } - - public class LocalDropTarget extends DropTarget { - private JTextField id; - public LocalDropTarget(JTextField id) { - this.id = id; - } - - @Override - public void drop(DropTargetDropEvent e) { - e.acceptDrop(DnDConstants.ACTION_COPY); - e.getTransferable().getTransferDataFlavors(); - Transferable transferable = e.getTransferable(); - DataFlavor[] flavors = transferable.getTransferDataFlavors(); - for (DataFlavor flavor : flavors) { - if (flavor.isFlavorJavaFileListType()) { - try { - List files = (List) transferable.getTransferData(flavor); - for (File file : files) { - id.setText(file.getAbsolutePath()); - id.setCaretPosition(1); - } - } - catch (UnsupportedFlavorException ex) { - ex.printStackTrace(); - } - catch (IOException ex) { - ex.printStackTrace(); - } - } - } - e.dropComplete(true); - super.drop(e); - } - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnBrowse1) { - browse(true); - } else if (e.getSource() == bnBrowse2) { - browse(false); - } else if (e.getSource() == addBtn) { - addDependency(); - } else if (e.getSource() == rmvBtn) { - removeDependency(); - } - } - - private void browse(boolean firstProcessing) { - JFileChooser chooser = new JFileChooser(txt1.getText()); - //chooser.setFileSelectionMode(JFileChooser.FILES_ONLY); - chooser.setDialogTitle("Select preprocessing jar"); - int ret = chooser.showOpenDialog(new JFrame()); - if (ret == JFileChooser.APPROVE_OPTION) { - if (firstProcessing) { - txt1.setText(chooser.getSelectedFile().getAbsolutePath()); - txt1.setCaretPosition(1); - } else { - txt2.setText(chooser.getSelectedFile().getAbsolutePath()); - txt2.setCaretPosition(1); - } - } - } - - /** - * Opens GUI to link Java dependencies - */ - public boolean addJavaDependencies() { - boolean result = false; - - GenericDialog dlg = new GenericDialog("Java Build Path and external files"); - dlg.addMessage("Add path to the Java .jar dependencies needed to run the pre-processing."); - dlg.addMessage("You can also add files required for the execution of the Java code, such as config files."); - dlg.addMessage("The formats allowed for Java dependencies are '.class' and '.jar'."); - dlg.addMessage("If there are no dependencies or files needed simply press 'OK'."); - - Panel loadPath = new Panel(); - loadPath.setLayout(new FlowLayout()); - loadPath.add(depPath); - depPath.setText("Drop file needed for pre-processing"); - depPath.setFont(new Font("Arial", Font.BOLD, 11)); - depPath.setForeground(Color.GRAY); - depPath.setPreferredSize(new Dimension(300, 50)); - - // Panel for buttons - JPanel buttons = new JPanel(); - buttons.setLayout(new GridLayout()); - buttons.add(addBtn); - buttons.add(rmvBtn); - loadPath.add(buttons, BorderLayout.EAST); - loadPath.setVisible(true); - dlg.addPanel(loadPath, GridBagConstraints.CENTER, new Insets(5, 0, 0, 0)); - Dimension panelSize = loadPath.getPreferredSize(); - - BorderPanel panel = new BorderPanel(); - dependenciesModel = new DefaultListModel(); - dependenciesModel.addElement(""); - dependenciesList = new JList(dependenciesModel); - dependenciesList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - dependenciesList.setLayoutOrientation(JList.VERTICAL); - dependenciesList.setVisibleRowCount(2); - JScrollPane listScroller = new JScrollPane(dependenciesList); - panel.add(listScroller); - dlg.addPanel(panel, GridBagConstraints.CENTER, new Insets(5, 0, 0, 0)); - - loadPath.setPreferredSize(new Dimension((int) Math.round(panelSize.getWidth() * 1), (int) Math.round(panelSize.getHeight() * 1))); - listScroller.setPreferredSize(new Dimension((int) Math.round(panelSize.getWidth() * 1), (int) Math.round(panelSize.getHeight() * 2))); - - dlg.showDialog(); - - if (dlg.wasOKed()) { - result = true; - } - return result; - } - - /** - * Adds dependency introduced to the list. Only accept .jar file - */ - public void addDependency() { - // Get the author introduced - String tag = depPath.getText().trim(); - if (tag.equals("")) { - IJ.error("Introduce the path to an external file."); - // Empty the text in the text field - depPath.setText("Drop file needed for pre-processing"); - return; - } else if(!(new File(tag)).isFile()) { - IJ.error("The path introduced does not correspond to an existing file."); - // Empty the text in the text field - depPath.setText("Drop file needed for pre-processing"); - return; - } - for (String dep : parent.getDeepPlugin().params.preAttachments) { - if (dep.contentEquals(tag)) { - IJ.error("Do not add the same file twice."); - // Empty the text in the text field - depPath.setText("Drop file needed for pre-processing"); - return; - } - } - // Check that the name of the files introduced does not coincide - // with the name of a file given during pre-processing - String preName = tag.substring(tag.lastIndexOf(File.separator) + 1); - for (String dep : parent.getDeepPlugin().params.preAttachments) { - String pName = dep.substring(dep.lastIndexOf(File.separator) + 1); - if (pName.contentEquals(preName) && !tag.endsWith(".jar")) { - IJ.error("A file called '" + preName + "' was already added for pre-processing.\n" - + "Cannot add two files with the same name."); - // Empty the text in the text field - depPath.setText("Drop file needed for pre-processing"); - return; - } - } - - parent.getDeepPlugin().params.preAttachments.add(tag); - - dependenciesModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : parent.getDeepPlugin().params.preAttachments){ - dependenciesModel.addElement(name); - } - dependenciesList.setModel(dependenciesModel); - // Empty the text in the text field - depPath.setText("Drop file needed for pre-processing"); - } - - /** - * Remove dependency previoulsy introduced - */ - public void removeDependency() { - // Get the author selected - int tag = dependenciesList.getSelectedIndex(); - if (tag == -1) { - IJ.error("No file selected to remove"); - return; - } - parent.getDeepPlugin().params.preAttachments.remove(tag); - - dependenciesModel = new DefaultListModel(); - - // Add the elements to the list - - for (String name : parent.getDeepPlugin().params.preAttachments){ - dependenciesModel.addElement(name); - } - dependenciesList.setModel(dependenciesModel); - } -} diff --git a/src/main/java/deepimagej/stamp/LoadPytorchStamp.java b/src/main/java/deepimagej/stamp/LoadPytorchStamp.java deleted file mode 100755 index ab31d07a..00000000 --- a/src/main/java/deepimagej/stamp/LoadPytorchStamp.java +++ /dev/null @@ -1,386 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.io.File; -import java.io.IOException; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Parameter; -import java.net.MalformedURLException; -import java.net.URL; -import java.util.ArrayList; -import java.util.Vector; -import java.util.concurrent.atomic.AtomicBoolean; - -import javax.swing.BoxLayout; -import javax.swing.JPanel; -import javax.swing.JTextField; - -import ai.djl.MalformedModelException; -import ai.djl.engine.Engine; -import ai.djl.engine.EngineException; -import ai.djl.ndarray.NDList; -import ai.djl.pytorch.jni.LibUtils; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ModelZoo; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepLearningModel; -import deepimagej.Parameters; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.SystemUsage; -import ij.IJ; -import ij.gui.GenericDialog; - -public class LoadPytorchStamp extends AbstractStamp implements Runnable { - - private JTextField inpNumber = new JTextField(); - private JTextField outNumber = new JTextField(); - - private HTMLPane pnLoad; - - - public LoadPytorchStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - pnLoad = new HTMLPane(Constants.width, 70); - - HTMLPane pnTag = new HTMLPane(Constants.width / 2, 70); - pnTag.append("h2", "Number of inputs"); - pnTag.append("p", "Number of inputs to the Pytorch model"); - - HTMLPane pnGraph = new HTMLPane(2 * Constants.width / 2, 70); - pnGraph.append("h2", "Number of outputs"); - pnGraph.append("p", "Number of outputs of the Pytorch model."); - - JPanel pn = new JPanel(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.add(pnTag.getPane()); - pn.add(inpNumber); - inpNumber.setText("0"); - inpNumber.setEnabled(false); - pn.add(pnGraph.getPane()); - pn.add(outNumber); - outNumber.setText("0"); - outNumber.setEnabled(false); - - JPanel main = new JPanel(new BorderLayout()); - main.add(pnLoad.getPane(), BorderLayout.CENTER); - main.add(pn, BorderLayout.SOUTH); - panel.add(main); - } - - @Override - public void init() { - Thread thread = new Thread(this); - thread.setPriority(Thread.MIN_PRIORITY); - thread.start(); - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - params.totalInputList = new ArrayList(); - params.totalOutputList = new ArrayList(); - boolean inp = false; - try { - int nInp = Integer.parseInt(inpNumber.getText().trim()); - inp = true; - int nOut = Integer.parseInt(outNumber.getText().trim()); - if (nOut < 1) { - IJ.error("The number of outputs shoud be 1 or bigger"); - return false; - } else if (nInp < 1) { - IJ.error("The number of inputs shoud be 1 or bigger"); - return false; - } - for (int i = 0; i < nInp; i ++) { - // TODO when possible add dimensions from model - DijTensor inpT = new DijTensor("input" + i); - params.totalInputList.add(inpT); - } - for (int i = 0; i < nOut; i ++) { - // TODO when possible add dimensions from model - DijTensor outT = new DijTensor("output" + i); - params.totalOutputList.add(outT); - } - return true; - - } catch (Exception ex) { - if (!inp) { - IJ.error("Please introduce a valid integer for the number of inputs."); - } else if (inp) { - IJ.error("Please introduce a valid integer for the number of outputs."); - } - return false; - } - } - - public void run() { - pnLoad.setCaretPosition(0); - pnLoad.setText(""); - pnLoad.append("p", "Loading Deep Java Library..."); - - Parameters params = parent.getDeepPlugin().params; - params.selectedModelPath = findPytorchModels(params.path2Model); - pnLoad.clear(); - params.pytorchVersion = DeepLearningModel.getPytorchVersion(); - pnLoad.append("h2", "Pytorch version"); - pnLoad.append("p", "Currently using Pytorch " + params.pytorchVersion); - pnLoad.append("p", "Supported by Deep Java Library " + params.pytorchVersion); - String cudaVersion = SystemUsage.getCUDAEnvVariables(); - // If a CUDA distribution was found, cudaVersion will be equal - // to the CUDA version. If not it can be either 'noCuda', if CUDA - // is not installed, or if there is a CUDA_PATH in the environment variables - // but the needed variables are not in the PATH, it will return the missing - // environment variables - if (cudaVersion.toLowerCase().equals("nocuda")) { - pnLoad.append("p", "No CUDA distribution found.\n"); - parent.setGPUTf("CPU"); - } else if (!cudaVersion.contains(File.separator) && !cudaVersion.contains("---")) { - pnLoad.append("p", "Currently using CUDA " + cudaVersion); - } else if (!cudaVersion.contains(File.separator) && cudaVersion.contains("---")) { - // In linux several CUDA versions are allowed. These versions will be separated by "---" - String[] versions = cudaVersion.split("---"); - if (versions.length == 1) { - pnLoad.append("p", "Currently using CUDA " + versions[0]); - } else { - for (String str : versions) - pnLoad.append("p", "Found CUDA " + str); - } - } else if ((cudaVersion.contains("bin") || cudaVersion.contains("libnvvp"))) { - String[] outputs = cudaVersion.split(";"); - pnLoad.append("p", "Found CUDA distribution " + outputs[0] + ".\n"); - pnLoad.append("p", "Could not find environment variable:\n - " + outputs[1] + "\n"); - if (outputs.length == 3) - pnLoad.append("p", "Could not find environment variable:\n - " + outputs[2] + "\n"); - pnLoad.append("p", "Please add the missing environment variables to the path.\n"); - } - pnLoad.append("p", DeepLearningModel.PytorchCUDACompatibility(params.pytorchVersion, cudaVersion)); - pnLoad.append("h2", "Model info"); - pnLoad.append("p", "Path: " + params.selectedModelPath); - pnLoad.append("

Loading model..."); - - // Load the model using DJL - boolean isFiji = SystemUsage.checkFiji(); - // If the plugin is running on an IJ1 distribution, set the IJ classloader - // as the Thread ContextClassLoader - if (!isFiji) - Thread.currentThread().setContextClassLoader(IJ.getClassLoader()); - // TODO allow the use of translators and transforms - URL url; - // Block back button while loading - parent.setEnabledBackNext(false); - try { - url = new File(new File(params.path2Model).getAbsolutePath()).toURI().toURL(); - - if (params.selectedModelPath.equals("")) { - pnLoad.append("No Pytorch model found in the directory."); - parent.setEnabledBack(true); - } - String modelName = new File(params.selectedModelPath).getName(); - modelName = modelName.substring(0, modelName.indexOf(".pt")); - long startTime = System.nanoTime(); - Criteria criteria = Criteria.builder() - .setTypes(NDList.class, NDList.class) - .optModelUrls(url.toString()) // search models in specified path - .optModelName(modelName) - .optProgress(new ProgressBar()).build(); - - ZooModel model = ModelZoo.loadModel(criteria); - parent.getDeepPlugin().setTorchModel(model); - pnLoad.append(" -> Loaded!!!

"); - params.pytorchVersion = Engine.getInstance().getVersion(); - String lib = new File(getNativeLbraryFile()).getName(); - if (!lib.toLowerCase().contains("cpu")) { - pnLoad.append("p", "Model loaded on the GPU.\n"); - parent.setGPUPt("GPU"); - } else { - pnLoad.append("p", "Model loaded on the CPU.\n"); - parent.setGPUPt("CPU"); - } - String torchscriptSize = "" + new File(params.selectedModelPath).length() / (1024 * 1024.0); - torchscriptSize = torchscriptSize.substring(0, torchscriptSize.lastIndexOf(".") + 2); - long stopTime = System.nanoTime(); - // Convert nanoseconds into seconds - String loadingTime = "" + ((stopTime - startTime) / (float) 1000000000); - loadingTime = loadingTime.substring(0, loadingTime.lastIndexOf(".") + 3); - pnLoad.append("p", "Model size: " + torchscriptSize + " Mb"); - pnLoad.append("p", "Loading time: " + loadingTime + " s"); - - parent.setEnabledBackNext(true); - inpNumber.setEnabled(true); - outNumber.setEnabled(true); - } catch (MalformedURLException e) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "Check that the path provided to the model remains the same."); - parent.setEnabledBack(true); - e.printStackTrace(); - } catch (EngineException e) { - String err = e.getMessage(); - String os = System.getProperty("os.name").toLowerCase(); - if (os.contains("win") && err.contains("https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md")) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "Please install the Visual Studio 2019 redistributables and reboot\n" - + "your machine to be able to use Pytorch with DeepImageJ."); - pnLoad.append("p", "For more information:\n"); - pnLoad.append("p", " -https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md"); - pnLoad.append("p", " -https://github.com/awslabs/djl/issues/126"); - pnLoad.append("p", "If you already have installed VS2019 redistributables, the error\n" - + "might be caused by a missing dependency or an incompatible Pytorch version."); - pnLoad.append("p", "Furthermore, the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) " - + "should be compatible with each other."); - pnLoad.append("p", "Please check the DeepImageJ Wiki."); - } else if((os.contains("linux") || os.contains("unix")) && err.contains("https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md")){ - pnLoad.append("p", "DeepImageJ could not load the model."); - pnLoad.append("p", "Check that there are no repeated dependencies on the jars folder."); - pnLoad.append("p", "The problem might be caused by a missing or repeated dependency or an incompatible Pytorch version."); - pnLoad.append("p", "Furthermore, the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) " - + "should be compatible with each other."); - pnLoad.append("p", "If the problem persists, please check the DeepImageJ Wiki."); - } else { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "Either the DJL Pytorch version is incompatible with the Torchscript model's " - + "Pytorch version or the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) " + - "are not compatible with each other."); - pnLoad.append("p", "Please check the DeepImageJ Wiki."); - } - parent.setEnabledBack(true); - e.printStackTrace(); - } catch (ModelNotFoundException e) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "No model was found in the path provided."); - parent.setEnabledBack(true); - e.printStackTrace(); - } catch (MalformedModelException e) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "The model provided is not a correct Torchscript model."); - parent.setEnabledBack(true); - e.printStackTrace(); - } catch (IOException e) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "Error whie accessing the model file."); - parent.setEnabledBack(true); - e.printStackTrace(); - } catch (Exception e) { - pnLoad.append("p", "DeepImageJ could not load the model"); - pnLoad.append("p", "Error whie accessing the model file."); - parent.setEnabledBack(true); - e.printStackTrace(); - } - } - - /* - * Find the Pytorch model (".pt" or ".pth") inside the folder provided. - * If there are more than one model, make the user decide. - */ - private String findPytorchModels(String modelPath) { - - File[] folderFiles = new File(modelPath).listFiles(); - ArrayList ptModels = new ArrayList(); - for (File file : folderFiles) { - if (file.getName().contains(".pt")) - ptModels.add(file); - } - - if (ptModels.size() == 1) - return ptModels.get(0).getAbsolutePath(); - - GenericDialog dlg = new GenericDialog("Choose Pytorch model"); - dlg.addMessage("The folder provided contained several Pytorch models"); - dlg.addMessage("Select which do you want to load."); - String[] fileArray = new String[ptModels.size()]; - int c = 0; - for (File f : ptModels) - fileArray[c ++] = f.getName(); - dlg.addChoice("Select framework", fileArray, fileArray[0]); - dlg.showDialog(); - if (dlg.wasCanceled()) { - dlg.dispose(); - return ""; - } - return modelPath + File.separator + dlg.getNextChoice(); - } - - /* - * Method to find the native libraey loaded by DJL to use Pytorch - */ - public static String getNativeLbraryFile() { - String nativeLibrary = "???"; - Field field; - try { - field = ClassLoader.class.getDeclaredField("loadedLibraryNames"); - field.setAccessible(true); - Object libraries = field.get(null); - if (libraries instanceof Vector) { - for (String ll : (Vector) libraries) { - if (ll.contains("torch_")){ - nativeLibrary = ll; - break; - } - } - } else if (libraries instanceof String[]) { - for (String ll : (String[]) libraries) { - if (ll.contains("torch_")){ - nativeLibrary = ll; - break; - } - } - } - } catch (Exception e) { - e.printStackTrace(); - } - return nativeLibrary; - } -} diff --git a/src/main/java/deepimagej/stamp/LoadTFStamp.java b/src/main/java/deepimagej/stamp/LoadTFStamp.java deleted file mode 100755 index 705c6394..00000000 --- a/src/main/java/deepimagej/stamp/LoadTFStamp.java +++ /dev/null @@ -1,408 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Set; - -import javax.swing.BoxLayout; -import javax.swing.JComboBox; -import javax.swing.JPanel; - -import org.tensorflow.SavedModelBundle; -import org.tensorflow.framework.SignatureDef; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.DeepLearningModel; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.FileTools; -import deepimagej.tools.Log; -import deepimagej.tools.StartTensorflowService; -import deepimagej.tools.SystemUsage; -import ij.IJ; - -public class LoadTFStamp extends AbstractStamp implements Runnable { - - private ArrayList tags; - private JComboBox cmbTags = new JComboBox(); - private JComboBox cmbGraphs = new JComboBox(); - //private ArrayList architecture = new ArrayList(); - private String name; - - private HTMLPane pnLoad; - - - public LoadTFStamp(BuildDialog parent) { - super(parent); - tags = new ArrayList(); - tags.add("Serve"); - buildPanel(); - } - - @Override - public void buildPanel() { - pnLoad = new HTMLPane(Constants.width, 70); - - HTMLPane pnTag = new HTMLPane(Constants.width / 2, 70); - pnTag.append("h2", "Model Tag"); - pnTag.append("p", "Tag used to save the TensorFlow SavedModel. If the plugin cannot automatically find it, you will need to edit it."); - - HTMLPane pnGraph = new HTMLPane(2 * Constants.width / 2, 70); - pnGraph.append("h2", "SignatureDef"); - pnGraph.append("p", "SignatureDef used to call the wanted model graph. There might be more than one in the same model folder."); - - JPanel pn = new JPanel(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.add(pnTag.getPane()); - pn.add(cmbTags); - pn.add(pnGraph.getPane()); - pn.add(cmbGraphs); - JPanel main = new JPanel(new BorderLayout()); - main.add(pnLoad.getPane(), BorderLayout.CENTER); - main.add(pn, BorderLayout.SOUTH); - panel.add(main); - } - - @Override - public void init() { - Thread thread = new Thread(this); - thread.setPriority(Thread.MIN_PRIORITY); - thread.start(); - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - if (params.tag == null) { - Log log = new Log(); - String tag = (String)cmbTags.getSelectedItem(); - try { - double time = System.nanoTime(); - SavedModelBundle model = DeepLearningModel.loadTf(params.path2Model, tag, log); - time = System.nanoTime() - time; - addLoadInfo(params, time); - parent.getDeepPlugin().setTfModel(model); - params.tag = tag; - cmbTags.setEditable(false); - parent.getDeepPlugin().setTfModel(model); - params.graphSet = DeepLearningModel.metaGraphsSet(model); - if (params.graphSet.size() > 0) { - Set tfGraphSet = DeepLearningModel.returnTfSig(params.graphSet); - cmbGraphs.removeAllItems(); - for (int i = 0; i < params.graphSet.size(); i++) { - cmbGraphs.addItem((String) tfGraphSet.toArray()[i]); - cmbGraphs.setEditable(false); - } - } - - } - catch (Exception e) { - IJ.error("Incorrect ModelTag"); - params.tag = null; - cmbTags.removeAllItems(); - cmbTags.setEditable(true); - } - return false; - } else { - // TODO put it inside run - SavedModelBundle model = parent.getDeepPlugin().getTfModel(); - params.graph = DeepLearningModel.returnStringSig((String)cmbGraphs.getSelectedItem()); - SignatureDef sig = DeepLearningModel.getSignatureFromGraph(model, params.graph); - params.totalInputList = new ArrayList<>(); - params.totalOutputList = new ArrayList<>(); - String[] inputs = DeepLearningModel.returnTfInputs(sig); - String[] outputs = DeepLearningModel.returnTfOutputs(sig); - Arrays.sort(inputs); - Arrays.sort(outputs); - pnLoad.append("p", "Number of outputs: " + outputs.length); - boolean valid = true; - try { - for (int i = 0; i < inputs.length; i ++) { - DijTensor inp = new DijTensor(inputs[i]); - inp.setInDimensions(DeepLearningModel.modelTfEntryDimensions(sig, inputs[i])); - params.totalInputList.add(inp); - } - for (int i = 0; i < outputs.length; i ++) { - DijTensor out = new DijTensor(outputs[i]); - out.setInDimensions(DeepLearningModel.modelTfExitDimensions(sig, outputs[i])); - params.totalOutputList.add(out); - } - } - catch (Exception ex) { - pnLoad.append("p", "Dimension: ERROR"); - valid = false; - parent.setEnabledBackNext(valid); - return false; - } - parent.setEnabledBackNext(valid); - return true; - } - } - - // TODO separate in methods - public void run() { - parent.setEnabledBack(false); - parent.setEnabledNext(false); - pnLoad.setCaretPosition(0); - pnLoad.setText(""); - pnLoad.append("p", "Loading available Tensorflow version."); - String loadInfo = "ImageJ"; - boolean isFiji = SystemUsage.checkFiji(); - if (isFiji) - loadInfo = StartTensorflowService.loadTfLibrary(); - - // If loadlLibrary() returns 'ImageJ', the plugin is running - // on an ImageJ1 instance - parent.setFiji(!loadInfo.contains("ImageJ")); - pnLoad.setCaretPosition(0); - pnLoad.setText(""); - if (loadInfo.equals("")) { - pnLoad.append("p", "Unable to find any Tensorflow distribution."); - pnLoad.append("p", "Please, install a valid Tensorflow version."); - parent.setEnabledBack(true); - return; - } - - Parameters params = parent.getDeepPlugin().params; - cmbTags.removeAllItems(); - cmbGraphs.removeAllItems(); - String tfVersion = DeepLearningModel.getTFVersion(parent.getFiji()); - pnLoad.clear(); - pnLoad.append("h2", "Tensorflow version"); - if (loadInfo.toLowerCase().contains("gpu")) - tfVersion += "_GPU"; - pnLoad.append("p", "Currently using Tensorflow " + tfVersion); - if (parent.getFiji()) { - pnLoad.append("p", loadInfo); - } else { - pnLoad.append("p", "To change the Tensorflow version, download the corresponding\n" - + "libtensorflow and libtensorflow_jni jars and copy them into\n" - + "the plugins folder."); - } - // Run the nvidia-smi to see if it is possible to locate a GPU - String cudaVersion = ""; - if (tfVersion.contains("GPU")) - cudaVersion = SystemUsage.getCUDAEnvVariables(); - else - parent.setGPUTf("CPU"); - // If a CUDA distribution was found, cudaVersion will be equal - // to the CUDA version. If not it can be either 'noCuda', if CUDA - // is not installed, or if there is a CUDA_PATH in the environment variables - // but the needed variables are not in the PATH, it will return the missing - // environment variables - if (tfVersion.contains("GPU") && cudaVersion.equals("nocuda")) { - pnLoad.append("p", "No CUDA distribution found.\n"); - parent.setGPUTf("CPU"); - } else if (tfVersion.contains("GPU") && !cudaVersion.contains(File.separator) && !cudaVersion.contains("---")) { - pnLoad.append("p", "Currently using CUDA " + cudaVersion); - pnLoad.append("p", DeepLearningModel.TensorflowCUDACompatibility(tfVersion, cudaVersion)); - } else if (tfVersion.contains("GPU") && !cudaVersion.contains(File.separator) && cudaVersion.contains("---")) { - // In linux several CUDA versions are allowed. These versions will be separated by "---" - String[] versions = cudaVersion.split("---"); - if (versions.length == 1) { - pnLoad.append("p", "Currently using CUDA " + versions[0]); - } else { - for (String str : versions) - pnLoad.append("p", "Found CUDA " + str); - } - pnLoad.append("p", DeepLearningModel.TensorflowCUDACompatibility(tfVersion, cudaVersion)); - } else if (tfVersion.contains("GPU") && (cudaVersion.contains("bin") || cudaVersion.contains("libnvvp"))) { - pnLoad.append("p", DeepLearningModel.TensorflowCUDACompatibility(tfVersion, cudaVersion)); - String[] outputs = cudaVersion.split(";"); - pnLoad.append("p", "Found CUDA distribution " + outputs[0] + ".\n"); - pnLoad.append("p", "Could not find environment variable:\n - " + outputs[1] + "\n"); - if (outputs.length == 3) - pnLoad.append("p", "Could not find environment variable:\n - " + outputs[2] + "\n"); - pnLoad.append("p", "Please add the missing environment variables to the path.\n"); - } - - pnLoad.append("h2", "Model info"); - File file = new File(params.path2Model); - if (file.exists()) - name = file.getName(); - - pnLoad.append("h2", "Load " + name); - - String pnTxt = pnLoad.getText(); - Log log = new Log(); - params.tag = null; - - // Block back button while loading - parent.setEnabledBackNext(false); - Object[] info = null; - double time = -1; - pnLoad.append("

Loading model..."); - ArrayList initialSmi = null; - ArrayList finalSmi = null; - try { - if (tfVersion.contains("GPU") && parent.getGPUTf().equals("")) - initialSmi = SystemUsage.runNvidiaSmi(); - double chrono = System.nanoTime(); - info = DeepLearningModel.findTfTag(params.path2Model); - time = System.nanoTime() - chrono; - if (tfVersion.contains("GPU") && parent.getGPUTf().equals("")) - finalSmi = SystemUsage.runNvidiaSmi(); - } catch (Exception ex) { - pnLoad.clear(); - pnLoad.setText(pnTxt); - ex.printStackTrace(); - IJ.error("DeepImageJ could not load the model," - + "try with another Tensorflow version"); - pnLoad.append("h2", "DeepImageJ could not load the model.\n"); - pnLoad.append("h2", "Try with another Tensorflow version.\n"); - // Let the developer go back, but no forward - parent.setEnabledBack(true); - parent.setEnabledNext(false); - return; - } - pnLoad.append(" -> Loaded!!!

"); - - // Check if the model has been loaded on GPU - if (tfVersion.contains("GPU") && !parent.getGPUTf().equals("GPU")) { - String GPUInfo = SystemUsage.isUsingGPU(initialSmi, finalSmi); - // TODO if the CUDA version is not compatible with the TF version, - // it is impossible to load the model on GPU - if (GPUInfo.equals("noImageJProcess") && !cudaVersion.contains(File.separator)) { - pnLoad.append("p", "Unable to run nvidia-smi to check if the model was loaded on a GPU.\n"); - parent.setGPUTf("???"); - } else if (GPUInfo.equals("noImageJProcess")) { - pnLoad.append("p", "Unable to load model on GPU.\n"); - parent.setGPUTf("CPU"); - } else if(GPUInfo.equals("¡RepeatedImageJGPU!")) { - int nImageJInstances = SystemUsage.numberOfImageJInstances(); - // Get number of IJ instances using GPU - int nGPUIJInstances = GPUInfo.split("¡RepeatedImageJGPU!").length; - if (nImageJInstances > nGPUIJInstances) { - pnLoad.append("p", "Found " + nGPUIJInstances + "instances of ImageJ/Fiji using GPU" - + " out of the " + nImageJInstances + " opened.\n"); - pnLoad.append("p", "Could not assert that the model was loaded on the GPU.\n"); - parent.setGPUTf("???"); - } else if (nImageJInstances <= nGPUIJInstances) { - pnLoad.append("p", "Model loaded on the GPU.\n"); - if (cudaVersion.contains("bin") || cudaVersion.contains("libnvvp")) - pnLoad.append("p", "Note that with missing environment variables, GPU performance might not be optimal.\n"); - parent.setGPUTf("GPU"); - } - } else { - pnLoad.append("p", "Model loaded on the GPU.\n"); - if (cudaVersion.contains("bin") || cudaVersion.contains("libnvvp")) - pnLoad.append("p", "Note that due to missing environment variables, GPU performance might not be optimal.\n"); - parent.setGPUTf("GPU"); - } - } else if (tfVersion.contains("GPU")) { - pnLoad.append("p", "Model loaded on the GPU.\n"); - if (cudaVersion.contains("bin") || cudaVersion.contains("libnvvp")) - pnLoad.append("p", "Note that due to missing environment variables, GPU performance might not be optimal.\n"); - } - - String tag = (String) info[0]; - if (tag != null) { - params.tag = tag; - String tfTag = DeepLearningModel.returnTfTag(tag); - cmbTags.addItem(tfTag); - cmbTags.setEditable(false); - SavedModelBundle model = null; - if (!(info[2] instanceof SavedModelBundle)) { - model = DeepLearningModel.loadTf(params.path2Model, params.tag, log); - } else { - model = (SavedModelBundle) info[2]; - addLoadInfo(params, time); - } - parent.getDeepPlugin().setTfModel(model); - try { - params.graphSet = DeepLearningModel.metaGraphsSet(model); - } catch (Exception ex) { - ex.printStackTrace(); - IJ.error("DeepImageJ could not load the model,\n" - + "try with another Tensorflow version"); - pnLoad.append("h2", "DeepImageJ could not load the model.\n"); - pnLoad.append("h2", "Try with another Tensorflow version.\n"); - // Let the developer go back, but no forward - parent.setEnabledBack(true); - parent.setEnabledNext(false); - return; - } - if (params.graphSet.size() > 0) { - Set tfGraphSet = DeepLearningModel.returnTfSig(params.graphSet); - for (int i = 0; i < params.graphSet.size(); i++) { - cmbGraphs.addItem((String) tfGraphSet.toArray()[i]); - cmbGraphs.setEditable(false); - } - } - } else { - cmbTags.addItem(""); - cmbTags.setEditable(true); - cmbGraphs.addItem(""); - cmbGraphs.setEditable(false); - pnLoad.append("p", "The plugin could not load the model automatically,
" - + "please introduce the needed information to load the model."); - } - // If we loaded either a Bioimage Zoo or Tensoflow model we continue - parent.setEnabledBackNext(true); - } - - /* - * Check if the classes - */ - - /* - * Add load information to the panel - */ - private void addLoadInfo(Parameters params, double time) { - pnLoad.append("p", "Path to model: " + params.path2Model + "\n"); - String timeStr = (time / 1000000000) + ""; - timeStr = timeStr.substring(0, timeStr.lastIndexOf(".") + 3); - pnLoad.append("p", "Time to load model: " + timeStr + " s\n"); - String modelSize = "" + FileTools.getFolderSize(params.path2Model + File.separator + "variables") / (1024*1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - pnLoad.append("p", "Size of the weights: " + modelSize + " MB"); - - } -} diff --git a/src/main/java/deepimagej/stamp/OutputDimensionStamp.java b/src/main/java/deepimagej/stamp/OutputDimensionStamp.java deleted file mode 100755 index 07c18f21..00000000 --- a/src/main/java/deepimagej/stamp/OutputDimensionStamp.java +++ /dev/null @@ -1,777 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Container; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.GridLayout; -import java.awt.Insets; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.BoxLayout; -import javax.swing.JButton; -import javax.swing.JComboBox; -import javax.swing.JFrame; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.JTextField; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.GridPanel; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Index; -import ij.IJ; - -public class OutputDimensionStamp extends AbstractStamp implements ActionListener { - - private static List firstRowList = new ArrayList(); - private static List secondRowList = new ArrayList(); - private static List thirdRowList = new ArrayList(); - - private static List> cmbRowList = new ArrayList>(); - - private static GridPanel pnOutputInfo = new GridPanel(true); - private static GridPanel firstRow = new GridPanel(true); - private static GridPanel secondRow = new GridPanel(true); - private static GridPanel thirdRow = new GridPanel(true); - private static JLabel firstLabel = new JLabel("Scaling factor"); - private static JLabel secondLabel = new JLabel("Halo factor"); - private static JLabel thirdLabel = new JLabel("Offset factor"); - private static JPanel pn = new JPanel(); - - private static JComboBoxreferenceImage = new JComboBox(new String[] {"aux"}); - private static JLabel refLabel = new JLabel("Reference input image"); - private static JLabel lblName = new JLabel("Name"); - - private static JLabel lblType = new JLabel("Output type: "); - - - private static JButton bnNextOutput = new JButton("Next Output"); - private static JButton bnPrevOutput = new JButton("Previous Output"); - - private static int outputCounter = 0; - private String model = ""; - - private static double[] rangeOptions = {Double.NEGATIVE_INFINITY, (double) -1, (double) 0, (double) 1, Double.POSITIVE_INFINITY}; - - private static JComboBox cmbRangeLow = new JComboBox(new String [] {"-inf", "-1", "0", "1", "inf"}); - private static JComboBox cmbRangeHigh = new JComboBox(new String [] {"-inf", "-1", "0", "1", "inf"}); - - // Parameter to observe if there has been any change in form or tensor type - private static List savedInputs = null; - - - public OutputDimensionStamp(BuildDialog parent) { - super(parent); - buildPanel(); - cmbRangeHigh.setSelectedIndex(4); - } - - @Override - public void buildPanel() { - - HTMLPane info = new HTMLPane(Constants.width, 150); - info.append("h2", "Model output size constraints"); - info.append("p", "model output size = model input size * scale + 2 * offset"); - info.append("p", "valid model output size -> model input size * scale + 2 * offset - 2*halo > 0"); - info.append("
    "); - info.append("li", "

    Scaling factor: the factor by which the output image " - + "dimensions are rescaled. E.g. in superresolution, if the output size " - + "is twice the size of the input, the scaling factor should be [2,2]. See the equation." - + "Scaling factor can only be positive integers for axes XYZ. In the channels axis" - + "the scaling factor can also be 0

    "); - info.append("li", "

    Offset factor: Difference between the input and output size. Note that " - + "this is different from a scaling factor. See the equation. For axes XYZ the " - + "offset does not affect the size of the final reconstructed image, it will be the" - + "as the orginal one, only applying the corresponding scaling. However, the offset" - + "of the C axis will does affect the size of the reconstructed image.

    "); - info.append("li", "

    Halo facto: Size of the receptive field of one pixel in the " - + "network used to avoid artifacts along the borders of the image. If the " - + "convolutions inside the network do not use padding, set this value to 0." - + "Halo cannot be a negative integer.

    "); - info.append("
"); - info.append("h2", "Final reconstructed image size"); - info.append("p", "The constraints defined above apply differently for" - + " the raw output of the model and for the final reconstructed image" - + "depending on the axis. For the final reconstructed image the following apply:"); - info.append("
    "); - info.append("li", "

    Final reconstructed image (X) = input image (X) * scaling factor (X)

    "); - info.append("li", "

    Final reconstructed image (Y) = input image (Y) * scaling factor (Y)

    "); - info.append("li", "

    Final reconstructed image (Z) = input image (Z) * scaling factor (Z)

    "); - info.append("li", "

    Final reconstructed image (C) = input image (C) * scaling factor (C) + 2 * offset factor (X)

    "); - info.append("
"); - - firstRow.place(0, 0, firstLabel); - firstRow.place(0, 1, new JLabel("aux")); firstRow.place(1, 1, new JTextField("aux")); - secondRow.place(0, 0, secondLabel); - secondRow.place(0, 1, new JLabel("aux")); secondRow.place(1, 1, new JTextField("aux")); - thirdRow.place(0, 0, thirdLabel); - thirdRow.place(0, 1, new JLabel("aux")); thirdRow.place(1, 1, new JTextField("aux")); - - JFrame pnFr = new JFrame(); - Container cn = pnFr.getContentPane(); - cn.setLayout(new GridBagLayout()); - - GridBagConstraints labelC = new GridBagConstraints(); - labelC.gridwidth = 2; - labelC.gridheight = 1; - labelC.gridx = 0; - labelC.gridy = 0; - labelC.ipadx = 0; - labelC.weightx = 1; - labelC.insets = new Insets(10, 10, 0, 5); - - // Set the output name - lblName.setText("Output name: NAME"); - cn.add(lblName, labelC); - // Set the output name - labelC.insets = new Insets(5, 10, 0, 5); - labelC.gridy = 1; - cn.add(lblType, labelC); - - // Set the reference image - labelC.gridy = 2; - cn.add(refLabel, labelC); - labelC.gridwidth = 1; - labelC.ipadx = 2; - labelC.gridy = 2; - labelC.gridx = 2; - labelC.weightx = 0.2; - labelC.insets = new Insets(0, 10, 0, 5); - cn.add(referenceImage, labelC); - - - labelC.gridwidth = 1; - labelC.gridheight = 1; - labelC.gridx = 0; - labelC.gridy = 3; - labelC.ipadx = 5; - labelC.weightx = 0.1; - labelC.insets = new Insets(10, 20, 5, 5); - - GridBagConstraints infoC = new GridBagConstraints(); - infoC.gridwidth = 10; - infoC.gridheight = 3; - infoC.gridx = 0; - infoC.gridy = 3; - infoC.ipadx = 15; - infoC.ipady = 10; - infoC.weightx = 0.9; - infoC.anchor = GridBagConstraints.CENTER; - infoC.fill = GridBagConstraints.BOTH; - infoC.insets = new Insets(5, 20, 5, 20); - - // First field - cn.add(firstRow, infoC); - - // Second field - labelC.gridy = 6; - infoC.gridy = 6; - cn.add(secondRow, infoC); - - // Third field - labelC.gridy = 9; - infoC.gridy = 9; - cn.add(thirdRow, infoC); - - // Data range combo boxes - labelC.gridy = 13; - infoC.gridy = 13; - infoC.gridx = 2; - infoC.gridwidth = 1; - infoC.ipadx = 5; - infoC.ipady = 2; - labelC.insets = new Insets(5, 20, 5, 5); - infoC.insets = new Insets(5, 5, 5, 15); - cn.add(new JLabel("Data Range: lower bound"), labelC); - cn.add(cmbRangeLow, infoC); - - labelC.gridx = 6; - labelC.insets = new Insets(5, 15, 5, 5); - infoC.insets = new Insets(5, 5, 5, 20); - cn.add(new JLabel("Data Range: upper bound"), labelC); - infoC.gridx = 8; - cn.add(cmbRangeHigh, infoC); - - pn = new JPanel(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.add(info.getPane()); - pn.add(cn); - - - JPanel pnButtons = new JPanel(new GridLayout(1, 2)); - pnButtons.setBorder(BorderFactory.createEtchedBorder()); - pnButtons.add(bnPrevOutput); - pnButtons.add(bnNextOutput); - pn.add(pnButtons, BorderLayout.SOUTH); - - panel.add(pn); - - bnNextOutput.addActionListener(this); - bnPrevOutput.addActionListener(this); - outputCounter = 0; - - } - - @Override - public void init() { - Parameters params = parent.getDeepPlugin().params; - // Set the screen at the first input if the model changes - String modelOfInterest = params.path2Model; - if (!modelOfInterest.equals(model)) { - model = modelOfInterest; - savedInputs = DijTensor.copyTensorList(params.outputList); - outputCounter = 0; - } else if (params.outputList.size() != savedInputs.size()) { - savedInputs = DijTensor.copyTensorList(params.outputList); - outputCounter = 0; - } else { - // Check if any output tensor definition has changed, an if it has - // put the attention on it - boolean changed = false; - for (int i = 0; i < params.outputList.size(); i ++) { - boolean sameTensor = params.outputList.get(i).name.equals(savedInputs.get(i).name); - boolean sameForm = params.outputList.get(i).form.equals(savedInputs.get(i).form); - boolean sameType = params.outputList.get(i).tensorType.equals(savedInputs.get(i).tensorType) ; - if (!sameTensor || !sameType || !sameForm) { - // Ask the user to repeat the tensor changed - params.outputList.get(i).finished = false; - // Start at the first tensor changed - if (!changed) { - savedInputs = DijTensor.copyTensorList(params.outputList); - outputCounter = i; - changed = true; - } - } - - } - } - bnNextOutput.setEnabled(true); - bnPrevOutput.setEnabled(true); - referenceImage.removeAllItems(); - for (DijTensor in : params.inputList) { - if (in.tensorType.contains("image")) - referenceImage.addItem(in.name); - } - updateInterface(params); - - - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - saveOutputData(params); - for (DijTensor tensor : params.outputList) { - if (!tensor.finished){ - IJ.error("You need to fill information for every input tensor"); - return false; - } - } - - return true; - } - - public static void updateInterface(Parameters params) { - - // Check how many outputs there are to enable or not - // the "next" and "back" buttons - if (outputCounter == 0) { - bnPrevOutput.setEnabled(false); - } else { - bnPrevOutput.setEnabled(true); - } - if (outputCounter < (params.outputList.size() - 1)) { - bnNextOutput.setEnabled(true); - } else { - bnNextOutput.setEnabled(false); - } - - lblName.setText("Output name: " + params.outputList.get(outputCounter).name); - lblType.setText("Output type: " + params.outputList.get(outputCounter).tensorType); - // Reinitialise all the params - pnOutputInfo.removeAll(); firstRow.removeAll(); secondRow.removeAll(); thirdRow.removeAll(); - firstRowList = new ArrayList(); secondRowList = new ArrayList(); thirdRowList = new ArrayList(); - pn.remove(0); - if (params.outputList.get(outputCounter).tensorType.contains("image") && !params.pyramidalNetwork) { - // Build panel for image - writeInfoText("image"); - getPanelForImage(params); - } else if (params.outputList.get(outputCounter).tensorType.contains("image") && params.pyramidalNetwork) { - // Build panel for pyramidal net - writeInfoText("pyramidalImage"); - getPanelForImagePyramidalNet(params); - }else if (params.outputList.get(outputCounter).tensorType.contains("list")) { - // Build panel for list - writeInfoText("list"); - getPanelForList(params); - } else { - outputCounter ++; - } - pnOutputInfo.revalidate(); - pnOutputInfo.repaint(); - - } - - /* - * Method to retrieve from the UI the information necessary to build a - * list from the tensor outputed by the model - */ - public static boolean saveOutputDataForList(Parameters params) { - // There is no offset or halo in the case of the outùt being a list. - // There is also no scale, but for convenience we will set it to 1. - DijTensor tensor = params.outputList.get(outputCounter); - tensor.scale = new float[tensor.tensor_shape.length]; - tensor.halo = new int[tensor.tensor_shape.length]; - tensor.offset = new float[tensor.tensor_shape.length]; - // Set the scale equal to 1 for every dimension - for (int i = 0; i < tensor.scale.length; i ++) - tensor.scale[i] = 1; - // Now do the important thing in this step. Change the dimension letters - // by C if it correspond to the column, or R if it corresponds to row - int batchInd = DijTensor.getBatchInd(tensor.form); - // Form containing rows and cols - String newForm = ""; - - int cmbCount = 0; - for (int i = 0; i < params.outputList.get(outputCounter).scale.length; i++) { - if (i == batchInd) { - newForm = newForm + "B"; - } else { - String selectedItem = String.valueOf(cmbRowList.get(cmbCount).getSelectedItem()); - String letter = selectedItem.split("")[0]; - cmbCount ++; - if (newForm.indexOf(letter) == -1) { - newForm = newForm + letter; - } else { - IJ.error("You cannot select the same field in both combo boxes."); - return false; - } - } - } - tensor.form = newForm; - return true; - } - - /* - * Method to retrieve from the UI the information necessary to build an - * image from the tensor outputed by the model - */ - public static boolean saveOutputDataForImage(Parameters params) { - // Save all the information for the output given by the variable 'outputInd' - String ref = (String) referenceImage.getSelectedItem(); - params.outputList.get(outputCounter).referenceImage = ref; - - params.outputList.get(outputCounter).scale = new float[params.outputList.get(outputCounter).tensor_shape.length]; - params.outputList.get(outputCounter).halo = new int[params.outputList.get(outputCounter).tensor_shape.length]; - params.outputList.get(outputCounter).offset = new float[params.outputList.get(outputCounter).tensor_shape.length]; - int batchInd = DijTensor.getBatchInd(params.outputList.get(outputCounter).form); - - int textFieldInd = 0; - for (int i = 0; i < params.outputList.get(outputCounter).scale.length; i++) { - try { - float scaleValue = 1; int haloValue = 0; float offsetValue = 0; - if (i == batchInd) { - params.outputList.get(outputCounter).scale[i] = 1; - params.outputList.get(outputCounter).halo[i] = 0; - params.outputList.get(outputCounter).offset[i] = 0; - } else { - // If the value for scale is "-" because there is no dimension in the reference image, - // save it as -1 - scaleValue = Float.valueOf(firstRowList.get(textFieldInd).isEditable() ? firstRowList.get(textFieldInd).getText() : "-1"); - params.outputList.get(outputCounter).scale[i] = scaleValue; - haloValue = Integer.valueOf(secondRowList.get(textFieldInd).getText()); - params.outputList.get(outputCounter).halo[i] = haloValue; - // If the value for offset is "-" because there is no dimension in the reference image, - // save it as 0 because offset can be negative - offsetValue = Float.parseFloat(thirdRowList.get(textFieldInd).isEditable() ? thirdRowList.get(textFieldInd).getText() : "0"); - // TODO if the offset is positive for X and Y dimensions, open an error saying that this is not supported yet - // TODO decide how to robustly manage offsets - String currentDim = params.outputList.get(outputCounter).form.split("")[i].toLowerCase(); - if (offsetValue > 0 && !currentDim.toLowerCase().equals("c")) { - IJ.error("Positive offset values are not\n" - + "supported yet for dimensions X and Y."); - return false; - } else if (offsetValue % 0.5 != 0) { - IJ.error("The offset should be a multiple of 0.5."); - return false; - } - // Do not allow 0 scaling factors in the XYZ dimensions - params.outputList.get(outputCounter).offset[i] = offsetValue; - if (scaleValue == 0 && !currentDim.toLowerCase().equals("c")) { - IJ.error("A 0 scaling factor is not allowed for X, Y or Z dimensions"); - return false; - } - // Do not allow negative scaling values - if (scaleValue < 0) { - IJ.error("Scaling factors can only be positive integers."); - return false; - } - // Do not allow negative halo values - if (haloValue < 0) { - IJ.error("Halo factors can only be positive."); - return false; - } - - textFieldInd++; - } - } catch( NumberFormatException ex) { - IJ.error("Make sure that no text field is empty and\n" - + "that they correspond to real numbers."); - return false; - } - } - return true; - } - - /* - * Method to retrieve from the UI the information necessary to build an - * image from the tensor outputed by the model in the case the model has - * a Pyramidal structure - */ - public static boolean saveOutputDataForImagePyramidalNet(Parameters params) { - // Save all the information for the output given by the variable 'outputInd' - // Get the reference tensor - - params.outputList.get(outputCounter).sizeOutputPyramid = new int[params.outputList.get(outputCounter).tensor_shape.length]; - int batchInd = DijTensor.getBatchInd(params.outputList.get(outputCounter).form); - - int textFieldInd = 0; - for (int i = 0; i < params.outputList.get(outputCounter).sizeOutputPyramid.length; i++) { - try { - int sizeOutputPyramid = 1; - if (i == batchInd) { - params.outputList.get(outputCounter).sizeOutputPyramid[i] = 1; - } else { - sizeOutputPyramid = Integer.valueOf(firstRowList.get(textFieldInd ++).getText()); - params.outputList.get(outputCounter).sizeOutputPyramid[i] = sizeOutputPyramid; - } - } catch( NumberFormatException ex) { - IJ.error("Make sure that no text field is empty and\n" - + "that they correspond to real numbers."); - return false; - } - } - return true; - } - - /* - * Method to retrieve from the UI the information necessary to build - * whatever object is needed for the output tensor - */ - public static boolean saveOutputData(Parameters params) { - // If the methods saving the info were successful, wasSaved=true - boolean wasSaved = false; - if (params.outputList.get(outputCounter).tensorType.contains("image") && !params.pyramidalNetwork) { - wasSaved = saveOutputDataForImage(params); - } else if (params.outputList.get(outputCounter).tensorType.contains("image") && params.pyramidalNetwork) { - wasSaved = saveOutputDataForImagePyramidalNet(params); - } else { - wasSaved = saveOutputDataForList(params); - } - - int lowInd = cmbRangeLow.getSelectedIndex(); - int highInd = cmbRangeHigh.getSelectedIndex(); - if (lowInd >= highInd) { - IJ.error("The Data Range has to go from a value to a higher one."); - return false; - } - - params.outputList.get(outputCounter).dataRange[0] = rangeOptions[lowInd]; - params.outputList.get(outputCounter).dataRange[1] = rangeOptions[highInd]; - params.outputList.get(outputCounter).finished = wasSaved; - - //completeInfo[outputCounter] = wasSaved; - return wasSaved; - } - - private static void getPanelForImagePyramidalNet(Parameters params) { - - int[] dimValues = DijTensor.getWorkingDimValues(params.outputList.get(outputCounter).form, params.outputList.get(outputCounter).tensor_shape); - String[] dims = DijTensor.getWorkingDims(params.outputList.get(outputCounter).form); - - - firstLabel.setText("Output size"); - firstLabel.setVisible(true); - firstRow.place(0, 0, firstLabel); - - for (int i = 0; i < dimValues.length; i ++) { - JLabel dimLetter1 = new JLabel(dims[i].toLowerCase().contains("z") ? "N/i/z" : dims[i]); - JTextField txt1; - - int auxInd = params.outputList.get(outputCounter).form.indexOf(dims[i]); - - txt1 = new JTextField(params.outputList.get(outputCounter).finished ? "" + params.outputList.get(outputCounter).sizeOutputPyramid[auxInd] : "1", 5); - txt1.setEditable(true); - if (dimValues[i] != -1) { - txt1.setText("" + dimValues[i]); - txt1.setEditable(false); - } else if (dimValues[i] == -1) { - txt1.setText("" + 0); - txt1.setEditable(true); - } - - firstRow.place(0, i + 1, dimLetter1); - firstRow.place(1, i + 1, txt1); - - firstRowList.add(txt1); - } - - secondRow.setVisible(false); - thirdRow.setVisible(false); - - refLabel.setVisible(false); - referenceImage.setVisible(false); - - firstRow.revalidate(); - firstRow.repaint(); - - } - - private static void getPanelForImage(Parameters params) { - - int[] dimValues = DijTensor.getWorkingDimValues(params.outputList.get(outputCounter).form, params.outputList.get(outputCounter).tensor_shape); - String[] dims = DijTensor.getWorkingDims(params.outputList.get(outputCounter).form); - // Get the reference tensor and its working dimensions - DijTensor refTensor = DijTensor.retrieveByName((String) referenceImage.getSelectedItem(), params.inputList); - String[] refDims = DijTensor.getWorkingDims(refTensor.form); - - for (int i = 0; i < dimValues.length; i ++) { - JLabel dimLetter1 = new JLabel(dims[i]); - JLabel dimLetter2 = new JLabel(dims[i]); - JLabel dimLetter3 = new JLabel(dims[i]); - JTextField txt1; - JTextField txt2; - JTextField txt3; - - int auxInd = params.outputList.get(outputCounter).form.indexOf(dims[i]); - - txt1 = new JTextField(params.outputList.get(outputCounter).finished ? "" + params.outputList.get(outputCounter).scale[auxInd] : "1", 5); - txt2 = new JTextField(params.outputList.get(outputCounter).finished && params.allowPatching ? "" + params.outputList.get(outputCounter).halo[auxInd] : "0", 5); - txt3 = new JTextField(params.outputList.get(outputCounter).finished ? "" + params.outputList.get(outputCounter).offset[auxInd] : "0", 5); - // Scale and offset are always editable - txt1.setEditable(true); - txt3.setEditable(true); - // If we do not allow patching, do not allow using halo - txt2.setEditable(params.allowPatching); - - int inputFixedSize = findFixedInput(refTensor, dims[i]); - - if (dimValues[i] != -1 && inputFixedSize != -1) { - float scale = ((float) dimValues[i]) / ((float) inputFixedSize); - txt1.setText("" + scale); - txt1.setEditable(true); - } else if (!Arrays.toString(refDims).contains(dims[i])) { - // If the reference image does not contain the output - //dimension, the output size for that dimension will just be - // whatever comes out of the model - txt1.setText(" - "); txt1.setEditable(false); - txt2.setText("0"); txt2.setEditable(false); - txt3.setText(" - "); txt3.setEditable(false); - } - - firstRow.place(0, i + 1, dimLetter1); - firstRow.place(1, i + 1, txt1); - secondRow.place(0, i + 1, dimLetter2); - secondRow.place(1, i + 1, txt2); - thirdRow.place(0, i + 1, dimLetter3); - thirdRow.place(1, i + 1, txt3); - - firstRowList.add(txt1); - secondRowList.add(txt2); - thirdRowList.add(txt3); - } - refLabel.setVisible(true); - referenceImage.setVisible(true); - - firstLabel.setText("Scaling factor"); - secondLabel.setText("Halo factor"); - thirdLabel.setText("Offset factor"); - - firstRow.place(0, 0, firstLabel); - secondRow.place(0, 0, secondLabel); - thirdRow.place(0, 0, thirdLabel); - - secondRow.setVisible(true); - thirdRow.setVisible(true); - firstLabel.setVisible(true); - secondLabel.setVisible(true); - thirdLabel.setVisible(true); - - firstRow.revalidate();; - firstRow.repaint(); - secondRow.revalidate();; - secondRow.repaint(); - thirdRow.revalidate();; - thirdRow.repaint(); - } - - /* - * Create Jpanel corresponding to list output - */ - private static void getPanelForList(Parameters params) { - - cmbRowList = new ArrayList>(); - - int[] dimValues = DijTensor.getWorkingDimValues(params.outputList.get(outputCounter).form, params.outputList.get(outputCounter).tensor_shape); - String[] dims = DijTensor.getWorkingDims(params.outputList.get(outputCounter).auxForm); - - String[] newDims = DijTensor.getWorkingDims(params.outputList.get(outputCounter).form); - if (params.outputList.get(outputCounter).form.contains("R") || params.outputList.get(outputCounter).form.contains("C")) { - newDims = DijTensor.getWorkingDims(params.outputList.get(outputCounter).form); - } - - for (int i = 0; i < dimValues.length; i ++) { - JLabel dimLetter1 = new JLabel(""+ dims[i] + " (size=" + dimValues[i] + ")"); - JComboBox txt1; - - txt1 = new JComboBox(new String[] {"Rows", "Columns"}); - if (newDims != null) - txt1.setSelectedIndex(newDims[i].equals("R") ? 0 : 1); - txt1.setEditable(false); - firstRow.place(0, i + 1, dimLetter1); - firstRow.place(1, i + 1, txt1); - cmbRowList.add(txt1); - } - - refLabel.setVisible(false); - referenceImage.setVisible(false); - - firstLabel.setVisible(false); - secondLabel.setVisible(false); - thirdLabel.setVisible(false); - secondRow.setVisible(false); - thirdRow.setVisible(false); - firstRow.revalidate(); - firstRow.repaint(); - } - - private static int findFixedInput(DijTensor referenceTensor, String dim) { - int fixed = -1; - if (referenceTensor != null) { - int ind = Index.indexOf(referenceTensor.form.split(""), dim); - if (ind != -1 && referenceTensor.step[ind] == 0) { - fixed = referenceTensor.minimum_size[ind]; - } - } - return fixed; - } - - public static void writeInfoText(String definition) { - HTMLPane info = new HTMLPane(Constants.width, 150); - if (definition.contains("image")) { - info.append("h2", "Model output size constraints"); - info.append("p", "model output size = model input size * scale + 2 * offset"); - info.append("p", "valid model output size -> model input size * scale + 2 * offset - 2*halo > 0"); - info.append("
    "); - info.append("li", "

    Scaling factor: the factor by which the output image " - + "dimensions are rescaled. E.g. in superresolution, if the output size " - + "is twice the size of the input, the scaling factor should be [2,2]. See the equation." - + "Scaling factor can only be positive integers for axes XYZ. In the channels axis" - + "the scaling factor can also be 0

    "); - info.append("li", "

    Offset factor: Difference between the input and output size. Note that " - + "this is different from a scaling factor. See the equation. For axes XYZ the " - + "offset does not affect the size of the final reconstructed image, it will be the" - + "as the orginal one, only applying the corresponding scaling. However, the offset" - + "of the C axis will does affect the size of the reconstructed image.

    "); - info.append("li", "

    Halo facto: Size of the receptive field of one pixel in the " - + "network used to avoid artifacts along the borders of the image. If the " - + "convolutions inside the network do not use padding, set this value to 0." - + "Halo cannot be a negative integer.

    "); - info.append("
"); - info.append("h2", "Final reconstructed image size"); - info.append("p", "The constraints defined above apply differently for" - + " the raw output of the model and for the final reconstructed image" - + "depending on the axis. For the final reconstructed image the following apply:"); - info.append("
    "); - info.append("li", "

    Final reconstructed image (X) = input image (X) * scaling factor (X)

    "); - info.append("li", "

    Final reconstructed image (Y) = input image (Y) * scaling factor (Y)

    "); - info.append("li", "

    Final reconstructed image (Z) = input image (Z) * scaling factor (Z)

    "); - info.append("li", "

    Final reconstructed image (C) = input image (C) * scaling factor (C) + 2 * offset factor (C)

    "); - info.append("
"); - info.append("p", "However, note that if patching is not allowed and the input size is fixed:"); - info.append("
    "); - info.append("li", "

    Final reconstructed image (X) = input image (X) * scaling factor (X) + 2 * offset factor (X)

    "); - info.append("li", "

    Final reconstructed image (Y) = input image (Y) * scaling factor (Y) + 2 * offset factor (Y)

    "); - info.append("li", "

    Final reconstructed image (Z) = input image (Z) * scaling factor (Z) + 2 * offset factor (Z)

    "); - info.append("li", "

    Final reconstructed image (C) = input image (C) * scaling factor (C) + 2 * offset factor (C)

    "); - info.append("
"); - } else if (definition.contains("pyramidalImage")) { - info.append("h", "Output size constraints"); - info.append("p", "Output size: Fixed output size of the model"); - } else if (definition.contains("list")) { - info.append("h", "Output size constraints"); - info.append("p", "Choose the dimension corresponding to rows and the dimension " - + "corresponding to columns."); - } - pn.add(info.getPane(), 0); - } - - @Override - public void actionPerformed(ActionEvent e) { - Parameters params = parent.getDeepPlugin().params; - - if (e.getSource() == bnNextOutput && outputCounter < (params.outputList.size() - 1)) { - if (saveOutputData(params)) { - outputCounter ++; - } - } else if (e.getSource() == bnPrevOutput && outputCounter > 0) { - outputCounter --; - } - updateInterface(params); - - } - -} diff --git a/src/main/java/deepimagej/stamp/PtSaveStamp.java b/src/main/java/deepimagej/stamp/PtSaveStamp.java deleted file mode 100755 index ec1dc3a9..00000000 --- a/src/main/java/deepimagej/stamp/PtSaveStamp.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Font; -import java.awt.Frame; -import java.awt.datatransfer.DataFlavor; -import java.awt.datatransfer.Transferable; -import java.awt.datatransfer.UnsupportedFlavorException; -import java.awt.dnd.DnDConstants; -import java.awt.dnd.DropTarget; -import java.awt.dnd.DropTargetDropEvent; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.JButton; -import javax.swing.JFileChooser; -import javax.swing.JFrame; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextField; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepImageJ; -import deepimagej.Parameters; -import deepimagej.components.HTMLPane; -import deepimagej.tools.FileTools; -import deepimagej.tools.YAMLUtils; -import ij.IJ; -import ij.ImagePlus; -import ij.WindowManager; -import ij.measure.ResultsTable; -import ij.text.TextWindow; - -public class PtSaveStamp extends AbstractStamp implements ActionListener, Runnable { - - private JTextField txt = new JTextField(IJ.getDirectory("imagej") + File.separator + "models" + File.separator); - private JButton bnBrowse = new JButton("Browse"); - private JButton bnSave = new JButton("Save Bundled Model"); - private HTMLPane pane; - - public PtSaveStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - public void buildPanel() { - pane = new HTMLPane(Constants.width, 320); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "Saving Bundled Model"); - JScrollPane infoPane = new JScrollPane(pane); - //infoPane.setPreferredSize(new Dimension(Constants.width, pane.getPreferredSize().height)); - DeepImageJ dp = parent.getDeepPlugin(); - - if (dp != null) - if (dp.params != null) - if (dp.params.path2Model != null) - txt.setText(dp.params.path2Model); - txt.setFont(new Font("Arial", Font.BOLD, 14)); - txt.setForeground(Color.red); - //txt.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load = new JPanel(new BorderLayout()); - load.setBorder(BorderFactory.createEtchedBorder()); - load.add(txt, BorderLayout.CENTER); - load.add(bnBrowse, BorderLayout.EAST); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(load, BorderLayout.NORTH); - pn.add(infoPane, BorderLayout.CENTER); - pn.add(bnSave, BorderLayout.SOUTH); - panel.add(pn); - - bnSave.addActionListener(this); - txt.setDropTarget(new LocalDropTarget()); - load.setDropTarget(new LocalDropTarget()); - bnBrowse.addActionListener(this); - } - - @Override - public void init() { - txt.setText(IJ.getDirectory("imagej") + File.separator + "models" + File.separator); - } - - @Override - public boolean finish() { - return true; - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnBrowse) { - browse(); - } - if (e.getSource() == bnSave) { - save(); - } - } - - private void browse() { - JFileChooser chooser = new JFileChooser(txt.getText()); - chooser.setFileSelectionMode(JFileChooser.DIRECTORIES_ONLY); - chooser.setDialogTitle("Select model"); - int ret = chooser.showSaveDialog(new JFrame()); - if (ret == JFileChooser.APPROVE_OPTION) { - txt.setText(chooser.getSelectedFile().getAbsolutePath()); - txt.setCaretPosition(1); - } - } - - public void save() { - Thread thread = new Thread(this); - thread.setPriority(Thread.MIN_PRIORITY); - thread.start(); - - } - @Override - public void run() { - DeepImageJ dp = parent.getDeepPlugin(); - Parameters params = dp.params; - params.biozoo = true; - params.saveDir = txt.getText() + File.separator; - params.saveDir = params.saveDir.replace(File.separator + File.separator, File.separator); - File dir = new File(params.saveDir); - - dir = new File(params.saveDir); - - if (dir.exists() && dir.isDirectory()) { - pane.append("p", "Path introduced corresponded to an already existing directory."); - pane.append("p", "Model not saved"); - IJ.error("Directory: \n" + dir.getAbsolutePath() + "\n already exists. Please introduce other name."); - return; - } - - if (!dir.exists()) { - dir.mkdir(); - pane.append("p", "Make a directory: " + params.saveDir); - } - if (!dir.exists()) { - pane.append("p", "Model not saved"); - IJ.error("This directory is not valid to save"); - return; - } - - // Save the model architecture - try { - File torchfile = new File(params.selectedModelPath); - FileTools.copyFile(torchfile, new File(dir + File.separator + "weights-torchscript" + ".pt")); - pane.append("p", "Torchscript model (.pt or .pth): saved"); - } catch (IOException e) { - e.printStackTrace(); - pane.append("p", "torchscript model (.pt or .pth): not saved"); - pane.append("p", "torchscript model (.pt or .pth): torchscript model was removed from the model path"); - } catch (Exception e) { - e.printStackTrace(); - pane.append("p", "torchscript model (.pt or .pth): not saved"); - } - - // List with processing files saved - ArrayList saved = new ArrayList(); - // Save preprocessing - if (params.firstPreprocessing != null) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.firstPreprocessing).getName()); - FileTools.copyFile(new File(params.firstPreprocessing), destFile); - pane.append("p", "First preprocessing: saved"); - saved.add(params.firstPreprocessing); - } - catch (Exception e) { - pane.append("p", "First preprocessing: not saved"); - } - } - if (params.secondPreprocessing != null && !saved.contains(params.secondPreprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.secondPreprocessing).getName()); - FileTools.copyFile(new File(params.secondPreprocessing), destFile); - pane.append("p", "Second preprocessing: saved"); - saved.add(params.secondPreprocessing); - } - catch (Exception e) { - pane.append("p", "Second preprocessing: not saved"); - } - } else if (params.secondPreprocessing != null) { - pane.append("p", "Second preprocessing: saved"); - } - - // Save postprocessing - if (params.firstPostprocessing != null && !saved.contains(params.firstPostprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.firstPostprocessing).getName()); - FileTools.copyFile(new File(params.firstPostprocessing), destFile); - pane.append("p", "First postprocessing: saved"); - } - catch (Exception e) { - pane.append("p", "First postprocessing: not saved"); - } - } else if (params.firstPostprocessing != null) { - pane.append("p", "First postprocessing: saved"); - } - if (params.secondPostprocessing != null && !saved.contains(params.secondPostprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.secondPostprocessing).getName()); - FileTools.copyFile(new File(params.secondPostprocessing), destFile); - pane.append("p", "Second postprocessing: saved"); - } - catch (Exception e) { - pane.append("p", "Second postprocessing: not saved"); - } - } else if (params.secondPostprocessing != null) { - pane.append("p", "Second postprocessing: saved"); - } - - // Save input image - try { - if (params.testImageBackup != null) { - // Get name with no extension - String title = TfSaveStamp.getTitleWithoutExtension(params.testImageBackup.getTitle().substring(4)); - IJ.saveAsTiff(params.testImageBackup, params.saveDir + File.separator + title + ".tif"); - pane.append("p", title + ".tif" + ": saved"); - boolean npySaved = TfSaveStamp.saveNpyFile(params.testImageBackup, "XYCZN", params.saveDir + File.separator + title + ".npy"); - if (npySaved) - pane.append("p", title + ".npy" + ": saved"); - else - pane.append("p", title + ".npy: not saved"); - params.testImageBackup.setTitle("DUP_" + params.testImageBackup.getTitle()); - } else { - throw new Exception(); - } - } - catch(Exception ex) { - pane.append("p", "exampleImage.tif: not saved"); - pane.append("p", "exampleImage.npy: not saved"); - } - - // Save output images and tables (tables as saved as csv) - for (HashMap output : params.savedOutputs) { - String name = output.get("name"); - String nameNoExtension = TfSaveStamp.getTitleWithoutExtension(name); - try { - if (output.get("type").contains("image")) { - ImagePlus im = WindowManager.getImage(name); - IJ.saveAsTiff(im, params.saveDir + File.separator + nameNoExtension + ".tif"); - im.setTitle(name); - pane.append("p", nameNoExtension + ".tif" + ": saved"); - boolean npySaved = TfSaveStamp.saveNpyFile(im, "XYCZB", params.saveDir + File.separator + nameNoExtension + ".npy"); - if (npySaved) - pane.append("p", nameNoExtension + ".npy" + ": saved"); - else - pane.append("p", nameNoExtension + ".npy: not saved"); - } else if (output.get("type").contains("ResultsTable")){ - Frame f = WindowManager.getFrame(name); - if (f!=null && (f instanceof TextWindow)) { - ResultsTable rt = ((TextWindow)f).getResultsTable(); - rt.save(params.saveDir + File.separator + nameNoExtension + ".csv"); - pane.append("p", nameNoExtension + ".csv" + ": saved"); - boolean npySaved = TfSaveStamp.saveNpyFile(rt, params.saveDir + File.separator + nameNoExtension + ".npy", "RC"); - if (npySaved) - pane.append("p", nameNoExtension + ".npy" + ": saved"); - else - pane.append("p", nameNoExtension + ".npy: not saved"); - } else { - throw new Exception(); } - } - } - catch(Exception ex) { - pane.append("p", nameNoExtension + ".tif: not saved"); - pane.append("p", nameNoExtension + ".npy: not saved"); - } - } - - // Save yaml - try { - YAMLUtils.writeYaml(dp); - pane.append("p", "rdf.yaml: saved"); - } - catch(IOException ex) { - pane.append("p", "rdf.yaml: not saved"); - IJ.error("Model file was locked or does not exist anymore."); - } - catch(Exception ex) { - pane.append("p", "rdf.yaml: not saved"); - ex.printStackTrace(); - } - - // Finally save external dependencies - boolean saveDeps = TfSaveStamp.saveExternalDependencies(params); - if (saveDeps && params.attachments.size() > 0) { - pane.append("p", "Java .jar dependencies: saved"); - } else if (!saveDeps && params.attachments.size() > 0) { - pane.append("p", "Java .jar dependencies: not saved"); - } - - pane.append("p", "Done!!"); - - } - - public class LocalDropTarget extends DropTarget { - - @Override - public void drop(DropTargetDropEvent e) { - e.acceptDrop(DnDConstants.ACTION_COPY); - e.getTransferable().getTransferDataFlavors(); - Transferable transferable = e.getTransferable(); - DataFlavor[] flavors = transferable.getTransferDataFlavors(); - for (DataFlavor flavor : flavors) { - if (flavor.isFlavorJavaFileListType()) { - try { - List files = (List) transferable.getTransferData(flavor); - for (File file : files) { - txt.setText(file.getAbsolutePath()); - txt.setCaretPosition(1); - } - } - catch (UnsupportedFlavorException ex) { - ex.printStackTrace(); - } - catch (IOException ex) { - ex.printStackTrace(); - } - } - } - e.dropComplete(true); - super.drop(e); - } - } - - -} diff --git a/src/main/java/deepimagej/stamp/SaveOutputFilesStamp.java b/src/main/java/deepimagej/stamp/SaveOutputFilesStamp.java deleted file mode 100755 index 612d1f88..00000000 --- a/src/main/java/deepimagej/stamp/SaveOutputFilesStamp.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.awt.Dimension; -import java.awt.Frame; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; - -import javax.swing.DefaultListModel; -import javax.swing.DefaultListSelectionModel; -import javax.swing.JButton; -import javax.swing.JList; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.ListSelectionModel; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.components.HTMLPane; -import ij.IJ; -import ij.ImagePlus; -import ij.WindowManager; -import ij.measure.ResultsTable; -import ij.text.TextWindow; - -public class SaveOutputFilesStamp extends AbstractStamp implements ActionListener { - - private JList openedList; - private DefaultListModel listModel; - private JButton refreshBtn = new JButton("Refresh"); - private HashMap imOrResultsTable; - private HTMLPane info; - - public SaveOutputFilesStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - info = new HTMLPane(Constants.width, 70); - info.append("h2", "Model output selection"); - info.append("p", "Select the images and tables that you want to save in" - + " the Bundled Model as the output of an example execution."); - info.append("p", "The input is saved automatically."); - listModel = new DefaultListModel(); - listModel.addElement("example"); - openedList = new JList(listModel); - - openedList.setSelectionMode(ListSelectionModel.SINGLE_INTERVAL_SELECTION); - openedList.setLayoutOrientation(JList.VERTICAL); - openedList.setVisibleRowCount(-1); - openedList.setSelectionModel(new DefaultListSelectionModel() { - @Override - public void setSelectionInterval(int index0, int index1) { - if(super.isSelectedIndex(index0)) { - super.removeSelectionInterval(index0, index1); - } - else { - super.addSelectionInterval(index0, index1); - } - } - }); - JScrollPane listScroller = new JScrollPane(openedList); - listScroller.setPreferredSize(new Dimension(Constants.width, panel.getPreferredSize().height)); - - JPanel main = new JPanel(new BorderLayout()); - main.add(info.getPane(), BorderLayout.NORTH); - main.add(listScroller, BorderLayout.CENTER); - main.add(refreshBtn, BorderLayout.SOUTH); - panel.add(main); - - refreshBtn.addActionListener(this); - } - - @Override - public void init() { - updateOutputList(); - openedList.revalidate(); - openedList.repaint(); - - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - List selections = openedList.getSelectedValuesList(); - - if (selections.size() == 0) { - IJ.error("You need to select at least one element from the list."); - return false; - } else { - // Save the names of the outputs to save - params.savedOutputs = new ArrayList>(); - for (String name : selections) { - HashMap out = new HashMap(); - String type = imOrResultsTable.get(name); - out.put("name", name); - out.put("type", type); - - // Write info about the selected info. This info includes - // the name, the type of object and the size - String size = ""; - if (imOrResultsTable.get(name).contains("image")) { - ImagePlus im = WindowManager.getImage(name); - int[] dims = im.getDimensions(); - size = Integer.toString(dims[0]) + " x " + Integer.toString(dims[1]) + " x " + Integer.toString(dims[2]) + " x " + Integer.toString(dims[3]); - } else if (imOrResultsTable.get(name).contains("ResultsTable")) { - Frame f = WindowManager.getFrame(name); - if (f!=null && (f instanceof TextWindow)) { - ResultsTable resultstable = ((TextWindow)f).getResultsTable(); - int cols = resultstable.getLastColumn() + 1; - if (cols == 0) - cols = 1; - size = Integer.toString(resultstable.size()) + " x " + Integer.toString(resultstable.getLastColumn()); - } - } - if (!size.equals("")) { - out.put("size", size); - } else { - IJ.error("Cannot save the " + imOrResultsTable.get(name) + " " + name + - "\nbecause it has already been closed."); - updateOutputList(); - return false; - } - - - params.savedOutputs.add(out); - } - } - - return true; - } - - public void updateOutputList(){ - - imOrResultsTable = new HashMap(); - listModel = new DefaultListModel(); - // Add the elements to the list - String[] imageTitles = WindowManager.getImageTitles(); - for (String title : imageTitles) { - listModel.addElement(title); - imOrResultsTable.put(title, "image"); - } - Frame[] nonImageWindows = WindowManager.getNonImageWindows(); - for (Frame f : nonImageWindows) { - if (f!=null && (f instanceof TextWindow)) { - String tableTitle = f.getTitle(); - listModel.addElement(tableTitle); - imOrResultsTable.put(tableTitle, "ResultsTable"); - } - } - openedList.setModel(listModel); - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == refreshBtn) { - updateOutputList(); - openedList.revalidate(); - openedList.repaint(); - } - - } -} diff --git a/src/main/java/deepimagej/stamp/SelectPyramidalStamp.java b/src/main/java/deepimagej/stamp/SelectPyramidalStamp.java deleted file mode 100755 index 15f4aeac..00000000 --- a/src/main/java/deepimagej/stamp/SelectPyramidalStamp.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.Insets; - -import javax.swing.JCheckBox; -import javax.swing.JPanel; - - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepImageJ; -import deepimagej.components.HTMLPane; - -public class SelectPyramidalStamp extends AbstractStamp { - - private JCheckBox checkPyramidal = new JCheckBox("Select if the model uses a Pyramidal Pooling architecture"); - private String model = ""; - private HTMLPane info; - - public SelectPyramidalStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - @Override - public void buildPanel() { - info = new HTMLPane(Constants.width, 70); - info.append("h2", "Pyramidal Feature Pooling Network selection"); - info.append("p", "Deep learning architectures that combine feature extraction and multilevel " - + "detection to infer bounding boxes. These networks are fed with more than one input " - + "and return outputs at different levels and dimensions. " - + "Examples: RetineNet, R-CNN, Fast-RCNN, Mask-RCNN or PanopticFPN."); - JPanel main = new JPanel(new BorderLayout()); - main.setLayout(new GridBagLayout()); - GridBagConstraints c = new GridBagConstraints(); - c.gridheight = 10; - c.gridx = 0; - c.gridy = 0; - c.ipadx = 70; - c.ipady = 70; - c.weightx = 1; - c.weighty = 1; - c.anchor = GridBagConstraints.NORTH; - c.fill = GridBagConstraints.HORIZONTAL; - main.add(info.getPane(), c); - c.gridheight = 1; - c.gridx = 0; - c.gridy = 10; - c.ipadx = 0; - c.ipady = 0; - c.weighty = 1; - c.anchor = GridBagConstraints.NORTH; - c.fill = GridBagConstraints.HORIZONTAL; - c.insets = new Insets(0, 50, 250, 10); - main.add(checkPyramidal, c); - panel.add(main); - checkPyramidal.setSelected(false); - } - - @Override - public void init() { - DeepImageJ dp = parent.getDeepPlugin(); - if (model.contains(dp.params.path2Model)) - checkPyramidal.setSelected(dp.params.pyramidalNetwork); - } - - @Override - public boolean finish() { - DeepImageJ dp = parent.getDeepPlugin(); - dp.params.pyramidalNetwork = false; - dp.params.allowPatching = true; - if (checkPyramidal.isSelected()) { - dp.params.pyramidalNetwork = true; - dp.params.allowPatching = false; - } - model = dp.params.path2Model; - return true; - } -} diff --git a/src/main/java/deepimagej/stamp/TensorPytorchTmpStamp.java b/src/main/java/deepimagej/stamp/TensorPytorchTmpStamp.java deleted file mode 100755 index c8c6854b..00000000 --- a/src/main/java/deepimagej/stamp/TensorPytorchTmpStamp.java +++ /dev/null @@ -1,463 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.Dimension; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.Insets; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -import javax.swing.BoxLayout; -import javax.swing.JComboBox; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.JScrollPane; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.DeepLearningModel; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Index; -import ij.IJ; -import ij.gui.GenericDialog; - -// TODO remove this class when DJL includes the functionality of finding inputs and outputs -// of a model -public class TensorPytorchTmpStamp extends AbstractStamp implements ActionListener { - - private List> inputs; - private static List> outputs; - private List> inTags; - private static List> outTags; - private String[] in = { "B", "Y", "X", "C", "Z", "-" }; - private String[] outPyramidal= { "B", "Y", "X", "C", "N/i/z", "-" }; - private String[] inputOptions= { "image", "parameter"}; - private static String[] outOptions = { "image", "list", "ignore"}; - private HTMLPane pnDim; - private JPanel pn = new JPanel(); - private JPanel pnInOut = new JPanel(); - private int iterateOverComboBox; - private String model = ""; - private boolean pyramidal = false; - - public TensorPytorchTmpStamp(BuildDialog parent) { - super(parent); - } - - @Override - public void buildPanel() { - HTMLPane info = new HTMLPane(Constants.width, 100); - info.append("h2", "Tensor Organization"); - info.append("p", "Each dimension of input and output tensors must" - + " be specified to process the image correctly, i.e. " - + "the first dimension of the input tensor corresponds to the batch size, " - + "the second dimension to the width and so on.
" - + "Note that for the moment DeepImageJ only supports BATCH_SIZE = 1."); - info.setMaximumSize(new Dimension(Constants.width, 100)); - - //pnInOut.setBorder(BorderFactory.createEtchedBorder()); - - Parameters params = parent.getDeepPlugin().params; - pnInOut.removeAll(); - List inputTensors = params.totalInputList; - List outputTensors = params.totalOutputList; - - // Set the correct information about each tensor - pnDim= new HTMLPane(Constants.width, 250); - File file = new File(parent.getDeepPlugin().params.path2Model); - String dirname = "untitled"; - if (file.exists()) - dirname = file.getName(); - // Save the model we are using to build the interface to check if - // we need to rebuild the panel or not. Same for pyramidal - model = params.path2Model; - pyramidal = params.pyramidalNetwork; - - pnDim.append("h2", "Tensor organization of " + dirname); - pnDim.append("

Input tensor types:

"); - pnDim.append("
  • Image: X for width (axis X), Y for" - + " height (axis Y), Z for depth (axis Z), B for batch, " - + "C for channel. DeepImageJ can only process one image at a time. If the " - + "input is not an image, you should include a pre-processing written in Java.

  • "); - pnDim.append("
  • Parameter: If the input is a parameter, the corresponding tensor must be " - + "created using Java pre-processing," - + " thus no dimension specification is needed. The tensor is fed directly " - + "to the model from pre-processing.

"); - - pnDim.append("

Output tensor types:

"); - if (params.pyramidalNetwork) { - pnDim.append("
  • Image: X for width (axis X), Y for" - + " height (axis Y), N/i/z for number of components/patches/objects or depth (axis Z), B for the batch, " - + "C for channel or class

  • "); - } else { - pnDim.append("
    • Image: X for width (axis X), Y for" - + " height (axis Y), Z for depth (axis Z), B for the batch, " - + "C for channel

    • "); - } - pnDim.append("
    • List: the tensor corresponds to a batch of matrices." - + "R for rows, C for columns, B for the batch. This type can be used for tensors" - + "with 3 dimensions at most (being one of them the batch).

    • "); - pnDim.append("
    • Ignore: DeepImageJ will not retrieve the tensor from the model.

    "); - pnDim.setMaximumSize(new Dimension(Constants.width, 150)); - GridBagConstraints cTag = new GridBagConstraints (); - cTag.gridwidth = 3; - cTag.gridx = 0; - cTag.insets = new Insets(3, 5, 3, 5); - GridBagConstraints cLabel = new GridBagConstraints (); - cLabel.gridwidth = 3; - cLabel.gridx = 3; - cLabel.insets = new Insets(3, 5, 3, 5); - - int nTensors = 0; - inTags = new ArrayList<>(); - outTags = new ArrayList<>(); - inputs = new ArrayList<>(); - outputs = new ArrayList<>(); - for (DijTensor input : inputTensors) { - // Create the panel that will contain all the elements for a tensor - JPanel pnTensor = new JPanel(new GridBagLayout()); - // Add the combo box to decide the type of input - JComboBox cmbInType = new JComboBox(inputOptions); - cmbInType.addActionListener(this); - pnTensor.add(cmbInType, cTag); - inTags.add(cmbInType); - // Add the name - pnTensor.add(new JLabel(input.name), cLabel); - // Now add the tensor specific dimensions - // TODO fix this when DJL adds retrieving sizes from model - for (int j = 0; j < 5; j ++) { - JComboBox cmbIn = new JComboBox(in); - cmbIn.setPreferredSize(new Dimension(50, 50)); - pnTensor.add(cmbIn); - inputs.add(cmbIn); - } - pnInOut.add(pnTensor); - nTensors ++; - } - - - for (DijTensor output : outputTensors) { - // Create the panel that will contain all the elements for a tensor - JPanel pnTensor = new JPanel(new GridBagLayout()); - // Add the combo box to decide the type of input - JComboBox cmbOutType = new JComboBox(outOptions); - cmbOutType.addActionListener(this); - pnTensor.add(cmbOutType, cTag); - outTags.add(cmbOutType); - // Add the name - pnTensor.add(new JLabel(output.name), cLabel); - // Now add the tensor specific dimensions - // TODO fix this when DJL adds retrieving sizes from model - for (int j = 0; j < 5; j ++) { - JComboBox cmbOut = new JComboBox(params.pyramidalNetwork ? outPyramidal : in); - cmbOut.setPreferredSize(new Dimension(50, 50)); - pnTensor.add(cmbOut); - outputs.add(cmbOut); - } - pnInOut.add(pnTensor); - nTensors ++; - } - JScrollPane scroll = new JScrollPane(); - pnInOut.setPreferredSize(new Dimension(500, nTensors * 60)); - scroll.setPreferredSize(new Dimension(600, nTensors * 70 + 50)); - scroll.setViewportView(pnInOut); - pn.removeAll(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.add(info.getPane()); - pn.add(pnDim.getPane()); - pn.add(scroll); - panel.add(pn); - - - } - - @Override - public void init() { - String modelOfInterest = parent.getDeepPlugin().params.path2Model; - if (!modelOfInterest.equals(model) || pyramidal != parent.getDeepPlugin().params.pyramidalNetwork) { - buildPanel(); - return; - } - - int inpCells = inTags.size(); - int newInpCells = parent.getDeepPlugin().params.totalInputList.size(); - int outCells = outTags.size(); - int newOutCells = parent.getDeepPlugin().params.totalOutputList.size(); - if (inpCells != newInpCells || outCells != newOutCells) - buildPanel(); - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - // Parameter to make sure only one tensor corresponds to the - // image type - // TODO support several image inputs - boolean image = false; - // Reset 'allowPatching' parameter to its default value (true) - if (!params.pyramidalNetwork) - params.allowPatching = true; - params.inputList = new ArrayList(); - List inputTensors = params.totalInputList; - iterateOverComboBox = 0; - int tagC = 0; - for (DijTensor tensor : inputTensors) { - tensor.form = ""; - for (int i = iterateOverComboBox; i < iterateOverComboBox + 5; i++) { - String selection = (String) inputs.get(i).getSelectedItem(); - if (!selection.contains("-")) { - tensor.form = tensor.form + selection; - } - } - int[] shape = new int[tensor.form.length()]; - for (int i = 0; i < shape.length; i ++) { shape[i] = -1;} - tensor.tensor_shape = shape; - tensor.tensorType = (String) inTags.get(tagC ++).getSelectedItem(); - // TODO accept more than one input image - if (!image && tensor.tensorType.contains("image")) { - image = true; - } else if (tensor.tensorType.contains("image")) { - IJ.error("The current DeepImageJ version only admits on input image tensor."); - return false; - } - iterateOverComboBox += tensor.tensor_shape.length; - if (checkRepeated(tensor.form) == false && !tensor.tensorType.equals("parameter")) { - IJ.error("Dimension repetition is not allowed."); - return false; - } - if (DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") == false && tensor.tensorType.equals("ignore") == false){ - IJ.error("The plugin only supports models with batch size (N) = 1"); - return false; - } - params.inputList.add(tensor); - } - params.outputList = new ArrayList(); - List outputTensors = params.totalOutputList; - tagC = 0; - iterateOverComboBox = 0; - for (DijTensor tensor : outputTensors) { - tensor.form = ""; - for (int i = iterateOverComboBox; i < iterateOverComboBox + 5; i++) { - String selection = (String) outputs.get(i).getSelectedItem(); - if (!selection.contains("-")) { - tensor.form = tensor.form + selection; - } - } - int[] shape = new int[tensor.form.length()]; - for (int i = 0; i < shape.length; i ++) { shape[i] = -1;} - tensor.tensor_shape = shape; - tensor.auxForm = tensor.form; - iterateOverComboBox += 5; - tensor.tensorType = (String) outTags.get(tagC ++).getSelectedItem(); - if (tensor.tensorType.contains("list")) - params.allowPatching = false; - - if (checkRepeated(tensor.form) == false && tensor.tensorType.equals("ignore") == false) { - IJ.error("Dmiension repetition is not allowed."); - return false; - } - if (DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") == false && tensor.tensorType.equals("ignore") == false){ - IJ.error("The plugin only supports models with batch size (N) = 1"); - return false; - } - } - for (Iterator iter = outputTensors.listIterator(); iter.hasNext(); ) { - DijTensor tensor = iter.next(); - if (!tensor.tensorType.contains("ignore")) { - params.outputList.add(tensor); - } - } - if (!image) { - IJ.error("The model must have at least 1 input image."); - return false; - } - if (params.outputList.size() < 1) { - IJ.error("The model must have at least 1 output."); - return false; - } - String msg = "The model has the following dimensions:\n"; - for (DijTensor tensor : params.inputList) - msg += " - " + tensor.name + ": " + tensor.form + "\n"; - for (DijTensor tensor : params.outputList) - msg += " - " + tensor.name + ": " + tensor.form + "\n"; - msg += "\n"; - msg += "Press 'Ok' if the dimensions are correct"; - - GenericDialog dlg = new GenericDialog("Model dimensions"); - dlg.addMessage(msg); - dlg.showDialog(); - - if (dlg.wasCanceled()) - return false; - - - return true; - } - - /* - * Change the letter in each Jcombobox depending on the selected tensor type. - */ - public void updateTensorDisplay(Parameters params) { - // Set disabled the tensors marked as 'parameter' - List inputTensors = params.totalInputList; - // Counter for tensors - int cIn = 0; - int cmbCounterIn = 0; - for (JComboBox cmbTag : inTags) { - int indSelection = cmbTag.getSelectedIndex(); - String selection = inputOptions[indSelection]; - for (int i = cmbCounterIn; i < cmbCounterIn + 5; i++) { - if (selection.contains("parameter") && inputs.get(i).getItemAt(0).equals("B")) { - inputs.get(i).removeAllItems(); - inputs.get(i).addItem("-"); - inputs.get(i).setEnabled(false); - String form = inputTensors.get(cIn).form; - if (form != null && !form.contentEquals("")) { - inputTensors.get(cIn).form = ""; - } - } else if (selection.contains("image") && inputs.get(i).getItemAt(0).equals("-")) { - inputs.get(i).removeAllItems(); - inputs.get(i).addItem("B"); - inputs.get(i).addItem("Y"); - inputs.get(i).addItem("X"); - inputs.get(i).addItem("C"); - inputs.get(i).addItem("Z"); - inputs.get(i).addItem("-"); - inputs.get(i).setEnabled(true); - String form = inputTensors.get(cIn).form; - if (form != null && !form.contentEquals("")) { - inputTensors.get(cIn).form = ""; - } - } - } - cmbCounterIn += 5; - cIn ++; - } - // Set disabled the tensors marked as 'ignore' - List outputTensors = params.totalOutputList; - // Counter for tensors - int c = 0; - int cmbCounter = 0; - for (JComboBox cmbTag : outTags) { - int indSelection = cmbTag.getSelectedIndex(); - String selection = outOptions[indSelection]; - for (int i = cmbCounter; i < cmbCounter + 5; i++) { - if (selection.contains("ignore")) { - outputs.get(i).setEnabled(!selection.equals("ignore")); - } else if (selection.contains("list") && outputs.get(i).getItemAt(1).equals("Y")) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("R"); - outputs.get(i).addItem("C"); - outputs.get(i).addItem("-"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("list") && !outputs.get(i).isEnabled()) { - outputs.get(i).setEnabled(true); - } else if (selection.contains("image") && outputs.get(i).getItemAt(1).equals("R")) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("Y"); - outputs.get(i).addItem("X"); - outputs.get(i).addItem("C"); - outputs.get(i).addItem("Z"); - outputs.get(i).addItem("-"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("image") && outputs.get(i).getItemAt(1).equals("R") && !params.pyramidalNetwork) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("Y"); - outputs.get(i).addItem("X"); - outputs.get(i).addItem("C"); - outputs.get(i).addItem("N/i/z"); - outputs.get(i).addItem("-"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("image") && !outputs.get(i).isEnabled()) { - outputs.get(i).setEnabled(true); - } - } - cmbCounter += 5; - c ++; - } - } - - private boolean checkRepeated(String form) { - // This method checks if the form given by the user - // has not repeated dimensions. If it has them, it throws - // an exception to alert the user. - for (int pos = 0; pos < form.length(); pos++) { - int last_index = Index.lastIndexOf(form.split(""), form.split("")[pos]); - if (last_index != pos) { - return false; - } - } - return true; - } - - @Override - public void actionPerformed(ActionEvent e) { - Parameters params = parent.getDeepPlugin().params; - updateTensorDisplay(params); - } -} diff --git a/src/main/java/deepimagej/stamp/TensorStamp.java b/src/main/java/deepimagej/stamp/TensorStamp.java deleted file mode 100755 index 77b6647b..00000000 --- a/src/main/java/deepimagej/stamp/TensorStamp.java +++ /dev/null @@ -1,429 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.Dimension; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.Insets; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; - -import javax.swing.BoxLayout; -import javax.swing.JComboBox; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.JScrollPane; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.Parameters; -import deepimagej.DeepLearningModel; -import deepimagej.components.HTMLPane; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Index; -import ij.IJ; - -public class TensorStamp extends AbstractStamp implements ActionListener { - - private List> inputs; - private static List> outputs; - private List> inTags; - private static List> outTags; - private String[] in = { "B", "Y", "X", "C", "Z" }; - private String[] outPyramidal= { "B", "Y", "X", "C", "N/i/z" }; - private String[] inputOptions= { "image", "parameter"}; - private static String[] outOptions = { "image", "list", "ignore"}; - private HTMLPane pnDim; - private JPanel pn = new JPanel(); - private JPanel pnInOut = new JPanel(); - private int iterateOverComboBox; - private String model = ""; - private boolean pyramidal = false; - - public TensorStamp(BuildDialog parent) { - super(parent); - } - - @Override - public void buildPanel() { - HTMLPane info = new HTMLPane(Constants.width, 100); - info.append("h2", "Tensor Organization"); - info.append("p", "Each dimension of input and output tensors must" - + " be specified to process the image correctly, i.e. " - + "the first dimension of the input tensor corresponds to the batch size, " - + "the second dimension to the width and so on.
    " - + "Note that for the moment DeepImageJ only supports BATCH_SIZE = 1."); - info.setMaximumSize(new Dimension(Constants.width, 100)); - - //pnInOut.setBorder(BorderFactory.createEtchedBorder()); - - Parameters params = parent.getDeepPlugin().params; - pnInOut.removeAll(); - List inputTensors = params.totalInputList; - List outputTensors = params.totalOutputList; - - // Set the correct information about each tensor - pnDim= new HTMLPane(Constants.width, 250); - File file = new File(parent.getDeepPlugin().params.path2Model); - String dirname = "untitled"; - if (file.exists()) - dirname = file.getName(); - pnDim.append("h2", "Input/Output tensor dimensions of " + dirname); - // Save the model we are using to build the interface to check if - // we need to rebuild the panel or not. Same for pyramidal. - model = params.path2Model; - pyramidal = params.pyramidalNetwork; - - for (DijTensor tensor : inputTensors) - pnDim.append("

    Input tensor --> " + tensor.name + " : " + Arrays.toString(tensor.tensor_shape) + "

    "); - for (DijTensor tensor : outputTensors) - pnDim.append("

    Output tensor --> " + tensor.name + " : " + Arrays.toString(tensor.tensor_shape) + "

    "); - - pnDim.append("
    "); - pnDim.append("

    Input tensor types:

    "); - pnDim.append("
    • Image: X for width (axis X), Y for" - + " height (axis Y), Z for depth (axis Z), B for batch, " - + "C for channel. DeepImageJ can only process one image at a time. If the input" - + " is not an image, you should include a pre-processing written in Java.

    • "); - pnDim.append("
    • Parameter: If the input is a parameter, the corresponding tensor must be " - + "created using Java pre-processing," - + " thus no dimension specification is needed. The tensor is fed directly " - + "to the model from pre-processing.

    "); - - pnDim.append("

    Output tensor types:

    "); - if (params.pyramidalNetwork) { - pnDim.append("
    • Image: X for width (axis X), Y for" - + " height (axis Y), N/i/z for number of components/patches/objects or depth (axis Z), B for the batch, " - + "C for channel or class

    • "); - } else { - pnDim.append("
      • Image: X for width (axis X), Y for" - + " height (axis Y), Z for depth (axis Z), B for the batch, " - + "C for channel

      • "); - } - pnDim.append("
      • List: the tensor corresponds to a batch of matrices." - + "R for rows, C for columns, B for the batch. This type can be used for tensors" - + "with 3 dimensions at most (being one of them the batch).

      • "); - pnDim.append("
      • Ignore: DeepImageJ will not retrieve the tensor from the model.

      "); - - pnDim.setMaximumSize(new Dimension(Constants.width, 250)); - //JPanel pnInput = new JPanel(new GridLayout(1, 5)); - GridBagConstraints cTag = new GridBagConstraints (); - cTag.gridwidth = 3; - cTag.gridx = 0; - cTag.insets = new Insets(3, 5, 3, 5); - GridBagConstraints cLabel = new GridBagConstraints (); - cLabel.gridwidth = 3; - cLabel.gridx = 3; - cLabel.insets = new Insets(3, 5, 3, 5); - - int nTensors = 0; - inTags = new ArrayList<>(); - outTags = new ArrayList<>(); - inputs = new ArrayList<>(); - outputs = new ArrayList<>(); - for (DijTensor input : inputTensors) { - // Create the panel that will contain all the elements for a tensor - JPanel pnTensor = new JPanel(new GridBagLayout()); - // Add the combo box to decide the type of input - JComboBox cmbInType = new JComboBox(inputOptions); - cmbInType.addActionListener(this); - pnTensor.add(cmbInType, cTag); - inTags.add(cmbInType); - // Add the name - pnTensor.add(new JLabel(input.name), cLabel); - // Now add the tensor specific dimensions - for (int j = 0; j < input.tensor_shape.length; j ++) { - JComboBox cmbIn = new JComboBox(in); - cmbIn.setPreferredSize(new Dimension(50, 50)); - pnTensor.add(cmbIn); - inputs.add(cmbIn); - } - pnInOut.add(pnTensor); - nTensors ++; - } - - - for (DijTensor output : outputTensors) { - // Create the panel that will contain all the elements for a tensor - JPanel pnTensor = new JPanel(new GridBagLayout()); - // Add the combo box to decide the type of input - JComboBox cmbOutType = new JComboBox(outOptions); - cmbOutType.addActionListener(this); - pnTensor.add(cmbOutType, cTag); - outTags.add(cmbOutType); - // Add the name - pnTensor.add(new JLabel(output.name), cLabel); - // Now add the tensor specific dimensions - for (int j = 0; j < output.tensor_shape.length; j ++) { - JComboBox cmbOut = new JComboBox(params.pyramidalNetwork ? outPyramidal : in); - cmbOut.setPreferredSize(new Dimension(50, 50)); - pnTensor.add(cmbOut); - outputs.add(cmbOut); - } - pnInOut.add(pnTensor); - nTensors ++; - } - JScrollPane scroll = new JScrollPane(); - pnInOut.setPreferredSize(new Dimension(500, nTensors * 60)); - scroll.setPreferredSize(new Dimension(600, nTensors * 70 + 50)); - scroll.setViewportView(pnInOut); - pn.removeAll(); - pn.setLayout(new BoxLayout(pn, BoxLayout.PAGE_AXIS)); - pn.add(info.getPane()); - pn.add(pnDim.getPane()); - pn.add(scroll); - panel.add(pn); - - - } - - @Override - public void init() { - String modelOfInterest = parent.getDeepPlugin().params.path2Model; - if (!modelOfInterest.equals(model) || pyramidal != parent.getDeepPlugin().params.pyramidalNetwork) { - buildPanel(); - } - } - - @Override - public boolean finish() { - Parameters params = parent.getDeepPlugin().params; - // Parameter to make sure only one tensor corresponds to the - // image type - // TODO support several image inputs - boolean image = false; - // Reset 'allowPatching' parameter to its default value (true) - if (!params.pyramidalNetwork) - params.allowPatching = true; - params.inputList = new ArrayList(); - List inputTensors = params.totalInputList; - iterateOverComboBox = 0; - int tagC = 0; - for (DijTensor tensor : inputTensors) { - tensor.form = ""; - for (int i = iterateOverComboBox; i < iterateOverComboBox + tensor.tensor_shape.length; i++) - tensor.form = tensor.form + (String) inputs.get(i).getSelectedItem(); - tensor.tensorType = (String) inTags.get(tagC ++).getSelectedItem(); - // TODO accept more than one input image - if (!image && tensor.tensorType.contains("image")) { - image = true; - } else if (tensor.tensorType.contains("image")) { - IJ.error("The current DeepImageJ version only admits on input image tensor."); - return false; - } - iterateOverComboBox += tensor.tensor_shape.length; - if (checkRepeated(tensor.form) == false && tensor.tensorType.equals("parameter") == false) { - IJ.error("Dimension repetition is not allowed"); - return false; - } - if (DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") == false && tensor.tensorType.equals("ignore") == false){ - IJ.error("The plugin only supports models with batch size (N) = 1"); - return false; - } - params.inputList.add(tensor); - } - params.outputList = new ArrayList(); - List outputTensors = params.totalOutputList; - tagC = 0; - iterateOverComboBox = 0; - for (DijTensor tensor : outputTensors) { - tensor.form = ""; - for (int i = iterateOverComboBox; i < iterateOverComboBox + tensor.tensor_shape.length; i++) { - String dimensionLetter = (String) outputs.get(i).getSelectedItem(); - dimensionLetter = dimensionLetter.toLowerCase().contains("z") ? "Z" : dimensionLetter; - tensor.form = tensor.form + dimensionLetter; - } - tensor.auxForm = tensor.form; - iterateOverComboBox += tensor.tensor_shape.length; - tensor.tensorType = (String) outTags.get(tagC ++).getSelectedItem(); - if (tensor.tensorType.contains("list")) { - params.allowPatching = false; - } - if (checkRepeated(tensor.form) == false && tensor.tensorType.equals("ignore") == false) { - IJ.error("Dimension repetition is not allowed"); - return false; - } - if (DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") == false && tensor.tensorType.equals("ignore") == false){ - IJ.error("The plugin only supports models with batch size (B) = 1"); - return false; - } - } - for (Iterator iter = outputTensors.listIterator(); iter.hasNext(); ) { - DijTensor tensor = iter.next(); - if (!tensor.tensorType.contains("ignore")) { - params.outputList.add(tensor); - } - } - if (!image) { - IJ.error("The model must have at least 1 input image."); - return false; - } - if (params.outputList.size() < 1) { - IJ.error("The model must have at least 1 output."); - return false; - } - - return true; - } - - /* - * Change the letter in each Jcombobox depending on the selected tensor type. - */ - public void updateTensorDisplay(Parameters params) { - // Set disabled the tensors marked as 'parameter' - List inputTensors = params.totalInputList; - // Counter for tensors - int cIn = 0; - int cmbCounterIn = 0; - for (JComboBox cmbTag : inTags) { - int indSelection = cmbTag.getSelectedIndex(); - String selection = inputOptions[indSelection]; - for (int i = cmbCounterIn; i < cmbCounterIn + inputTensors.get(cIn).tensor_shape.length; i++) { - if (selection.contains("parameter") && inputs.get(i).getItemAt(0).equals("B")) { - inputs.get(i).removeAllItems(); - inputs.get(i).addItem("-"); - inputs.get(i).setEnabled(false); - String form = inputTensors.get(cIn).form; - if (form != null && !form.contentEquals("")) { - inputTensors.get(cIn).form = ""; - } - } else if (selection.contains("image") && inputs.get(i).getItemAt(0).equals("-")) { - inputs.get(i).removeAllItems(); - inputs.get(i).addItem("B"); - inputs.get(i).addItem("Y"); - inputs.get(i).addItem("X"); - inputs.get(i).addItem("C"); - inputs.get(i).addItem("Z"); - inputs.get(i).setEnabled(true); - String form = inputTensors.get(cIn).form; - if (form != null && !form.contentEquals("")) { - inputTensors.get(cIn).form = ""; - } - } - } - cmbCounterIn += inputTensors.get(cIn).tensor_shape.length; - cIn ++; - } - // Set disabled the tensors marked as 'ignore' - List outputTensors = params.totalOutputList; - // Counter for tensors - int c = 0; - int cmbCounter = 0; - for (JComboBox cmbTag : outTags) { - int indSelection = cmbTag.getSelectedIndex(); - String selection = outOptions[indSelection]; - for (int i = cmbCounter; i < cmbCounter + outputTensors.get(c).tensor_shape.length; i++) { - if (selection.contains("ignore")) { - outputs.get(i).setEnabled(!selection.equals("ignore")); - } else if (selection.contains("list") && outputs.get(i).getItemAt(1).equals("Y")) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("R"); - outputs.get(i).addItem("C"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("list") && !outputs.get(i).isEnabled()) { - outputs.get(i).setEnabled(true); - } else if (selection.contains("image") && outputs.get(i).getItemAt(1).equals("R") && !params.pyramidalNetwork) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("Y"); - outputs.get(i).addItem("X"); - outputs.get(i).addItem("C"); - outputs.get(i).addItem("Z"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("image") && outputs.get(i).getItemAt(1).equals("R") && params.pyramidalNetwork) { - outputs.get(i).removeAllItems(); - outputs.get(i).addItem("B"); - outputs.get(i).addItem("Y"); - outputs.get(i).addItem("X"); - outputs.get(i).addItem("C"); - outputs.get(i).addItem("N/i/z"); - outputs.get(i).setEnabled(true); - //outputs.get(i).addActionListener(this); - String form = outputTensors.get(c).form; - if (form != null && !form.contentEquals("")) - outputTensors.get(c).form = ""; - } else if (selection.contains("image") && !outputs.get(i).isEnabled()) { - outputs.get(i).setEnabled(true); - } - } - cmbCounter += outputTensors.get(c).tensor_shape.length; - c ++; - } - } - - private boolean checkRepeated(String form) { - // This method checks if the form given by the user - // has not repeated dimensions. If it has them, it throws - // an exception to alert the user. - for (int pos = 0; pos < form.length(); pos++) { - int last_index = Index.lastIndexOf(form.split(""), form.split("")[pos]); - if (last_index != pos) { - return false; - } - } - return true; - } - - @Override - public void actionPerformed(ActionEvent e) { - Parameters params = parent.getDeepPlugin().params; - updateTensorDisplay(params); - } -} diff --git a/src/main/java/deepimagej/stamp/TestStamp.java b/src/main/java/deepimagej/stamp/TestStamp.java deleted file mode 100755 index 65feae26..00000000 --- a/src/main/java/deepimagej/stamp/TestStamp.java +++ /dev/null @@ -1,711 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; -import java.awt.BorderLayout; -import java.awt.Component; -import java.awt.Container; -import java.awt.GridLayout; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.awt.event.MouseEvent; -import java.awt.event.MouseListener; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -import javax.swing.BorderFactory; -import javax.swing.BoxLayout; -import javax.swing.JButton; -import javax.swing.JComboBox; -import javax.swing.JLabel; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextField; -import javax.swing.SwingUtilities; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepImageJ; -import deepimagej.Parameters; -import deepimagej.RunnerTf; -import deepimagej.RunnerProgress; -import deepimagej.RunnerPt; -import deepimagej.components.HTMLPane; -import deepimagej.tools.ArrayOperations; -import deepimagej.tools.DijRunnerPostprocessing; -import deepimagej.tools.DijRunnerPreprocessing; -import deepimagej.tools.DijTensor; -import deepimagej.tools.Log; -import ij.IJ; -import ij.ImagePlus; -import ij.WindowManager; - -public class TestStamp extends AbstractStamp implements ActionListener, MouseListener, Runnable { - - private HTMLPane pnTest; - private JButton bnTest = new JButton("Run a test"); - private JTextField axesTxt = new JTextField("C,Y,X"); - private JTextField sizeTxt = new JTextField("3,256,256"); - private List> cmbList = new ArrayList>(); - private List btnList = new ArrayList(); - private JPanel inputsPn = new JPanel(new GridLayout(3, 1)); - private String selectedImage = ""; - - private List imageTensors; - - public TestStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - public void buildPanel() { - SwingUtilities.invokeLater(() -> { - - pnTest = new HTMLPane(Constants.width, 100); - JScrollPane pnTestScroller = new JScrollPane(pnTest); - //pnTestScroller.setPreferredSize(new Dimension(Constants.width, pnTest.getPreferredSize().height)); - HTMLPane pane = new HTMLPane(Constants.width, 100); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "Run a test on an image"); - pane.append("p", "Select an input image."); - pane.append("p", "Introduce an image size that can be accepted by the model."); - pane.append("p", "The tile size will be used together with the parameters\n" - + "previously introduced to process the whole image."); - pane.append("p", "Take into account that if you are using CPU, the images\n" - + "processed at once cannot be too big due to memory limitations."); - pane.append("p", "The smallest size that allows processing the whole image with\n" - + "only 1 tile and that fulfils the parameters in suggested\n" - + "automatically"); - pane.append("p", "After setting the tile size for the test run, click on Run a test"); - - JPanel pn1 = new JPanel(); - pn1.setLayout(new BoxLayout(pn1, BoxLayout.Y_AXIS)); - JComboBox cmb = new JComboBox(); - cmbList.add(cmb); - JButton btn = retrieveJComboBoxArrow(cmb); - btnList.add(btn); - JPanel firstPn = new JPanel(); - firstPn.setLayout(new BoxLayout(firstPn, BoxLayout.LINE_AXIS)); - firstPn.setBorder(BorderFactory.createEmptyBorder(0, 10, 10, 10)); - //JPanel firstPn = new JPanel(new GridLayout(1, 2)); - firstPn.add(new JLabel("input")); - firstPn.add(cmb); - //JPanel secondPn = new JPanel(new GridLayout(1, 2)); - JPanel secondPn = new JPanel(); - secondPn.setLayout(new BoxLayout(secondPn, BoxLayout.LINE_AXIS)); - secondPn.setBorder(BorderFactory.createEmptyBorder(0, 10, 10, 10)); - secondPn.add(new JLabel("Axes")); - secondPn.add(axesTxt); - //JPanel thirdPn = new JPanel(new GridLayout(1, 2)); - JPanel thirdPn = new JPanel(); - thirdPn.setLayout(new BoxLayout(thirdPn, BoxLayout.LINE_AXIS)); - thirdPn.setBorder(BorderFactory.createEmptyBorder(0, 10, 10, 10)); - thirdPn.add(new JLabel("Input tile size")); - thirdPn.add(sizeTxt); - inputsPn.add(firstPn); - inputsPn.add(secondPn); - inputsPn.add(thirdPn); - pn1.add(inputsPn, BorderLayout.CENTER); - pn1.add(bnTest, BorderLayout.SOUTH); - pn1.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10)); - - JPanel pnt = new JPanel(); - pnt.setLayout(new BoxLayout(pnt, BoxLayout.PAGE_AXIS)); - pnt.add(pane.getPane()); - pnt.add(pn1); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(pnt, BorderLayout.NORTH); - pn.add(pnTestScroller, BorderLayout.CENTER); - - pnTest.setEnabled(true); - - panel.add(pn); - bnTest.addActionListener(this); - }); - } - - @Override - public void init() { - Parameters params = parent.getDeepPlugin().params; - inputsPn.removeAll(); - imageTensors = DijTensor.getImageTensors(params.inputList); - //inputsPn.setLayout(new GridLayout(3, 1)); - inputsPn.setLayout(new BoxLayout(inputsPn, BoxLayout.PAGE_AXIS)); - cmbList = new ArrayList>(); - btnList = new ArrayList(); - JComboBox cmb = new JComboBox(); - cmb = new JComboBox(); - String[] titlesList = WindowManager.getImageTitles(); - int c = 0; - for (DijTensor tensor : imageTensors) { - cmb = new JComboBox(); - if (titlesList.length != 0) { - for (String title : titlesList) - cmb.addItem(title); - cmbList.add(cmb); - bnTest.setEnabled(parent.getDeepPlugin() != null); - } else { - bnTest.setEnabled(false); - params.testImageBackup = null; - cmb.addItem("No image"); - cmbList.add(cmb); - } - - JPanel firstPn = new JPanel(); - firstPn.setLayout(new BoxLayout(firstPn, BoxLayout.LINE_AXIS)); - firstPn.setBorder(BorderFactory.createEmptyBorder(0, 0, 10, 10)); - JLabel lab1 = new JLabel(tensor.name); - firstPn.add(lab1); - firstPn.add(cmb); - lab1.setBorder(BorderFactory.createEmptyBorder(0, 10, 0, 10)); - JPanel secondPn = new JPanel(); - secondPn.setLayout(new BoxLayout(secondPn, BoxLayout.LINE_AXIS)); - secondPn.setBorder(BorderFactory.createEmptyBorder(0, 0, 10, 10)); - JLabel lab2 = new JLabel("Axes"); - secondPn.add(lab2); - secondPn.add(axesTxt); - lab2.setBorder(BorderFactory.createEmptyBorder(0, 10, 0, 10)); - setAxes(c); - JPanel thirdPn = new JPanel(); - thirdPn.setLayout(new BoxLayout(thirdPn, BoxLayout.LINE_AXIS)); - thirdPn.setBorder(BorderFactory.createEmptyBorder(0, 0, 10, 10)); - JLabel lab3 = new JLabel("Input tile size"); - thirdPn.add(lab3); - thirdPn.add(sizeTxt); - lab3.setBorder(BorderFactory.createEmptyBorder(0, 10, 0, 10)); - setOptimalPatch((String) cmb.getSelectedItem(), c); - inputsPn.add(firstPn); - inputsPn.add(secondPn); - inputsPn.add(thirdPn); - - btnList.add(retrieveJComboBoxArrow(cmb)); - btnList.get(c).addMouseListener(this); - cmbList.get(c).addMouseListener(this); - cmbList.get(c ++).addActionListener(this); - } - } - - @Override - public boolean finish() { - return true; - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnTest) { - // Check if all the input images are associated to an image - // opened in ImageJ - String[] titlesList = WindowManager.getImageTitles(); - for (int j = 0; j < cmbList.size(); j ++) { - String selectedOption = (String) cmbList.get(j).getSelectedItem(); - if (Arrays.asList(titlesList).contains(selectedOption)) - continue; - else { - IJ.error("Select images open in ImageJ"); - return; - } - } - // If all the images selected are opened in ImageJ, perform a test run - if (!testPreparation()) - return; - Thread thread = new Thread(this); - thread.setPriority(Thread.MIN_PRIORITY); - thread.start(); - } - - for (int i = 0; i < cmbList.size(); i ++) { - if (e.getSource() == cmbList.get(i)) { - // If all the selected items in every cmbBox correspond - // to an existing image, set the button enabled, if not, - // not enabled - cmbList.get(0).removeActionListener(this); - String[] titlesList = WindowManager.getImageTitles(); - for (int j = 0; j < cmbList.size(); j ++) { - String selectedOption = (String) cmbList.get(j).getSelectedItem(); - if (Arrays.asList(titlesList).contains(selectedOption)) { - setOptimalPatch(selectedOption, j); - continue; - } else { - bnTest.setEnabled(false); - return; - } - } - bnTest.setEnabled(true); - cmbList.get(0).addActionListener(this); - - break; - } - } - - // If the selected image has changed, update the patch size - // TODO generalise for several inputa images - String newSelectedImage = (String) cmbList.get(0).getSelectedItem(); - if (!selectedImage.contentEquals(newSelectedImage)) { - selectedImage = newSelectedImage; - if (Arrays.asList(WindowManager.getImageTitles()).contains(selectedImage)) - setOptimalPatch(selectedImage, 0); - } - } - - /* - * Prepare the inputs and check that the image complies with the - * model requirements - * Return false if something is wrong - */ - public boolean testPreparation() { - Parameters params = parent.getDeepPlugin().params; - - File file = new File(params.path2Model); - if (!file.exists()) { - IJ.error("The model was removed from its original location."); - return false; - } - String dirname = file.getName(); - bnTest.setEnabled(false); - pnTest.append("h2", "Test " + dirname); - - String[] images = new String[imageTensors.size()]; - for (int i = 0; i < images.length; i++) { - images[i] = (String)cmbList.get(i).getSelectedItem(); - // TODO generalise for several input images - String[] dims = DijTensor.getWorkingDims(imageTensors.get(i).form); - int[] tileSize = ArrayOperations.getPatchSize(dims, imageTensors.get(i).form, sizeTxt.getText(), sizeTxt.isEditable()); - boolean isTileCorrect = checkInputTileSize(tileSize, imageTensors.get(i).name, params); - if (tileSize == null || !isTileCorrect) - return false; - imageTensors.get(i).recommended_patch = tileSize; - } - if (images.length == 1) - params.testImage = WindowManager.getImage(images[0]); - String imagesNames = Arrays.toString(images); - - for (String im : images) { - if (WindowManager.getImage(im) == null) { - pnTest.append("p", im + " does not correspond to an open image"); - IJ.error("No selected test image."); - return false; - } - params.testImageBackup = WindowManager.getImage(im).duplicate(); - params.testImageBackup.setTitle("DUP_" + im); - } - - pnTest.append("Selected input images " + imagesNames); - return true; - } - - @Override - public void run() { - test(); - - } - - /*// TODO create methods to group code in charge of stopping the execution - * Perform a test run on the selected image - */ - public void test() { - String runnerError = ""; - ExecutorService service = Executors.newFixedThreadPool(1); - DeepImageJ dp = parent.getDeepPlugin(); - RunnerProgress rp = new RunnerProgress(dp, "preprocessing"); - String step = "pre"; - - try { - DijRunnerPreprocessing preprocess = new DijRunnerPreprocessing(dp, rp, null, false, true); - Future> f0 = service.submit(preprocess); - HashMap inputsMap = f0.get(); - if (rp.isStopped()) { - RunnerProgress.stopRunnerProgress(service, rp); - pnTest.append("p", "Test run was stoped during preprocessing."); - IJ.error("Test run was stoped during preprocessing."); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - return; - } else if (inputsMap == null && preprocess.error.contentEquals("")) { - RunnerProgress.stopRunnerProgress(service, rp); - pnTest.append("p", "Error during preprocessing."); - pnTest.append("p", "The preprocessing did not return anything."); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - return; - } else if (!preprocess.error.contentEquals("")) { - RunnerProgress.stopRunnerProgress(service, rp); - pnTest.append("p", preprocess.error); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - return; - } - - step = "model"; - HashMap output = null; - if (dp.params.framework.equals("tensorflow")) { - rp.setGPU(parent.getGPUTf()); - RunnerTf runner = new RunnerTf(dp, rp, inputsMap, new Log()); - rp.setRunner(runner); - // TODO decide what to store at the end of the execution - Future> f1 = service.submit(runner); - output = f1.get(); - runnerError = runner.error; - } else { - rp.setGPU(parent.getGPUPt()); - RunnerPt runner = new RunnerPt(dp, rp, inputsMap, new Log()); - rp.setRunner(runner); - // TODO decide what to store at the end of the execution - Future> f1 = service.submit(runner); - output = f1.get(); - runnerError = runner.error; - } - - if (output == null && !rp.isStopped()) { - RunnerProgress.stopRunnerProgress(service, rp); - pnTest.append("p", "Test run failed"); - pnTest.append("p", runnerError); - IJ.error("The execution of the model failed."); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - return; - } else if (rp.isStopped()) { - RunnerProgress.stopRunnerProgress(service, rp); - pnTest.append("p", "Model execution of the test run stopped"); - IJ.error("Model execution of the test run stopped."); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - return; - } - - step = "post"; - DijRunnerPostprocessing postprocess = new DijRunnerPostprocessing(dp, rp, output); - Future> f2 = service.submit(postprocess); - output = f2.get(); - - if (rp.isStopped()) { - pnTest.append("p", "Test run was stoped during postprocessing."); - IJ.error("Test run was stoped during postprocessing."); - } - - RunnerProgress.stopRunnerProgress(service, rp); - // Print the outputs of the postprocessing - // Retrieve the opened windows and compare them to what the model has outputed - // Display only what has not already been displayed - - String[] finalFrames = WindowManager.getNonImageTitles(); - String[] finalImages = WindowManager.getImageTitles(); - ArrayOperations.displayMissingOutputs(finalImages, finalFrames, output); - // Remove possible hidden images from IJ workspace - ArrayOperations.removeProcessedInputsFromMemory(inputsMap); - - parent.endsTest(); - bnTest.setEnabled(true); - pnTest.append("p", "Peak memory:" + dp.params.memoryPeak); - dp.params.runtime = rp.getRuntime(); - pnTest.append("p", "Runtime: " + dp.params.runtime + "s"); - - } catch (InterruptedException | ExecutionException e) { - e.printStackTrace(); - if (step.contains("pre")) { - pnTest.append("p", "Thread stopped working during the preprocessing.\n" - + "The reason might be a faulty preprocessing"); - IJ.error("Thread stopped working during the preprocessing.\n" - + "The reason might be a faulty preprocessing"); - } else if (step.contains("model")) { - pnTest.append("p", "Thread stopped working during the execution of the model."); - IJ.error("Thread stopped working during the execution of the model."); - } else if (step.contains("post")) { - pnTest.append("p", "Thread stopped working during the execution of the model.\n" - + "The reason might be a faulty postprocessing"); - IJ.error("Thread stopped working during the execution of the model.\n" - + "The reason might be a faulty postprocessing"); - } - } - RunnerProgress.stopRunnerProgress(service, rp); - } - - public JButton retrieveJComboBoxArrow(Container container) { - if (container instanceof JButton) { - return (JButton) container; - } else { - Component[] components = container.getComponents(); - for (Component component : components) { - if (component instanceof Container) { - return retrieveJComboBoxArrow((Container)component); - } - } - } - return null; - } - - /* - * This method sets the axes specified by the user separated by commas - */ - private void setAxes(int imageTensorInd) { - DijTensor tensor = imageTensors.get(imageTensorInd); - // Get basic information about the input from the yaml - String tensorForm = tensor.form; - String[] dim = DijTensor.getWorkingDims(tensorForm); - - String axesAux = ""; - for (String dd : dim) {axesAux += dd + ",";} - axesTxt.setText(axesAux.substring(0, axesAux.length() - 1)); - axesTxt.setEditable(false); - - } - - /* - * This method calculates an acceptable input tile size to the model - * considering the image selected size and the parameters set previously - * by the user - */ - private void setOptimalPatch(String selectedOption, int imageTensorInd) { - ImagePlus imp = WindowManager.getImage(selectedOption); - DijTensor tensor = imageTensors.get(imageTensorInd); - // Get basic information about the input from the yaml - String tensorForm = tensor.form; - // Minimum size if it is not fixed, 0s if it is - int[] tensorMin = tensor.minimum_size; - // Step if the size is not fixed, 0s if it is - int[] tensorStep = tensor.step; - float[] haloSize = ArrayOperations.findTotalPadding(tensor, parent.getDeepPlugin().params.outputList, parent.getDeepPlugin().params.pyramidalNetwork); - int[] min = DijTensor.getWorkingDimValues(tensorForm, tensorMin); - int[] step = DijTensor.getWorkingDimValues(tensorForm, tensorStep); - float[] haloVals = DijTensor.getWorkingDimValues(tensorForm, haloSize); - String[] dim = DijTensor.getWorkingDims(tensorForm); - - String optimalPatch = ArrayOperations.optimalPatch(imp, haloVals, dim, step, min, parent.getDeepPlugin().params.allowPatching); - - sizeTxt.setText(optimalPatch); - int auxFixed = 0; - for (int ss : step) - auxFixed += ss; - - sizeTxt.setEditable(true); - if (!parent.getDeepPlugin().params.allowPatching || parent.getDeepPlugin().params.pyramidalNetwork || auxFixed == 0) { - sizeTxt.setEditable(false); - } - } - - /* - * Check patch size introduced by the user complies with the parameters previously - * entered to define the model - */ - public static boolean checkInputTileSize(int[] tileSize, String tensorName, Parameters params) { - DijTensor inpTensor = DijTensor.retrieveByName(tensorName, params.inputList); - String[] form = inpTensor.form.split(""); - // TODO generalise error messages for several input images - // Check that the input fulfils the conditions - for (int i = 0; i < tileSize.length; i ++) { - int step = inpTensor.step[i]; - int min = inpTensor.minimum_size[i]; - int pp = tileSize[i]; - if (step == 0 && min != pp) { - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "The size for dimension " + form[i] + " should be\n" - + "equal to " + min + " and it is instead set to " + pp); - return false; - } else if (params.allowPatching && step != 0 && (pp - min) % step != 0) { - double n = Math.floor(((double)(pp - min)) / ((double) step)); - double sugest = n * step + min; - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "Every dimension of the tile size introduced must\n" - + "be the result of:\n" - + " minimum_size + step_size x N, where N is any\n" - + "posititive integer." - + "This condition is not fulfiled at dimension " + form[i] + "(" + pp + ").\n" - + "The immediately smaller value that fulfils the\n" - + "necessary condition is " + sugest); - return false; - } else if (params.allowPatching && pp <= 0) { - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "Every dimension of the tile size introduced" - + "must be strictly bigger than 0.\n" - + "Tile size at '" + form[i] + "' is " + pp); - return false; - } - } - - // Now check that the output makes sense - if (params.pyramidalNetwork) - return true; - for (DijTensor outTensor : params.outputList) { - if (!outTensor.tensorType.contains("image")) - continue; - String[] outForm = outTensor.form.split(""); - int[] outShape = outTensor.tensor_shape; - int[] halo = outTensor.halo; - float[] offset = outTensor.offset; - float[] scale = outTensor.scale; - for ( int i = 0; i < outForm.length; i ++) { - // Find which dimension corresponds to the current output dimension - int ind = inpTensor.form.indexOf(outForm[i]); - if (ind == -1) - continue; - int outSize = 0; - // Check that the input to the model is not automatically calculated by - // the plugin. If it is, we cannot make sure anything. - if (params.allowPatching || inpTensor.step[ind] == 0) - outSize = (int) (Math.round(tileSize[ind] * scale[i]) + 2 * offset[i]); - if ((params.allowPatching || inpTensor.step[ind] == 0) && outShape[i] != -1 && outSize != outShape[i]) { - // Check that with the given parameters, the input size gives the - // output size specified by the model - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "The output size for this model at dimension '" + outForm[i]+ "'\n" - + "is specified to be " + outShape[i] + ". Applying the\n" - + "scaling and offset specified for dimension '" + outForm[i]+ "'\n" - + "considering an input size of " + tileSize[ind] + " yields an\n" - + "incorrect output size of " + outSize + ". Please, correct these parameters."); - return false; - } else if ((params.allowPatching || inpTensor.step[ind] == 0) && outSize <= 0){ - // Check that taking into account halo and offset - // the output produced is bigger than 0 - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "Applying the scaling and offset for\n" - + "output '" + outTensor.name + "' at dimension '" + outForm[i] + "' the\n" - + "resulting output size was " + outSize + " which is\n" - + "smaller than 0. The output size cannot\n" - + "be negative. Please, correct these parameters."); - return false; - } else if (2 * halo[i] > outSize) { - // The size of the halo is too big for the chosen tile size - IJ.error(" INCORRECT INPUT TILE SIZE \n" - + "Applying the scaling, offset and halo for\n" - + "output '" + outTensor.name + "' at dimension '" + outForm[i] + "' the\n" - + "resulting output size was " + (outSize - 2 * halo[i]) + " which is\n" - + "smaller than 0. The output size cannot\n" - + "be negative. Please, correct these parameters."); - return false; - } - } - } - return true; - } - - @Override - /* - * Update the JComboBox list when it is clicked - */ - public void mouseClicked(MouseEvent e) { - // Check for clicks on the arrow - for (int i = 0; i < btnList.size(); i ++) { - if (e.getSource() == btnList.get(i)) { - String[] titlesList = WindowManager.getImageTitles(); - - cmbList.get(0).removeActionListener(this); - cmbList.get(0).removeAllItems(); - // Update the list of options provided by each of - // the images - for (int j = 0; j < cmbList.size(); j ++) { - //cmbList.get(j).removeAllItems(); - if (titlesList.length != 0) { - for (String title : titlesList) - cmbList.get(j).addItem(title); - bnTest.setEnabled(true); - } else { - cmbList.get(j).addItem("No image"); - bnTest.setEnabled(false); - } - } - cmbList.get(0).addActionListener(this); - - break; - } - } - - // Check for clicks on the text field - for (int i = 0; i < cmbList.size(); i ++) { - if (e.getSource() == cmbList.get(i)) { - String[] titlesList = WindowManager.getImageTitles(); - - cmbList.get(0).removeActionListener(this); - cmbList.get(0).removeAllItems(); - // Update the list of options provided by each of - // the images - for (int j = 0; j < cmbList.size(); j ++) { - //cmbList.get(j).removeAllItems(); - if (titlesList.length != 0) { - for (String title : titlesList) - cmbList.get(j).addItem(title); - bnTest.setEnabled(true); - } else { - cmbList.get(j).addItem("No image"); - bnTest.setEnabled(false); - } - } - cmbList.get(0).addActionListener(this); - - break; - } - } - - } - - @Override - public void mousePressed(MouseEvent e) { - // Not necessary for our use case - - } - - @Override - public void mouseReleased(MouseEvent e) { - // Not necessary for our use case - - } - - @Override - public void mouseEntered(MouseEvent e) { - // Not necessary for our use case - - } - - @Override - public void mouseExited(MouseEvent e) { - // Not necessary for our use case - - } -} diff --git a/src/main/java/deepimagej/stamp/TfSaveStamp.java b/src/main/java/deepimagej/stamp/TfSaveStamp.java deleted file mode 100755 index 9e844b1e..00000000 --- a/src/main/java/deepimagej/stamp/TfSaveStamp.java +++ /dev/null @@ -1,457 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Font; -import java.awt.Frame; -import java.awt.GridLayout; -import java.awt.datatransfer.DataFlavor; -import java.awt.datatransfer.Transferable; -import java.awt.datatransfer.UnsupportedFlavorException; -import java.awt.dnd.DnDConstants; -import java.awt.dnd.DropTarget; -import java.awt.dnd.DropTargetDropEvent; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; - -import javax.swing.BorderFactory; -import javax.swing.JButton; -import javax.swing.JFileChooser; -import javax.swing.JFrame; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextField; - -import org.jetbrains.bio.npy.NpyFile; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepImageJ; -import deepimagej.ImagePlus2Tensor; -import deepimagej.Parameters; -import deepimagej.Table2Tensor; -import deepimagej.components.HTMLPane; -import deepimagej.tools.FileTools; -import deepimagej.tools.YAMLUtils; -import ij.IJ; -import ij.ImagePlus; -import ij.WindowManager; -import ij.measure.ResultsTable; -import ij.text.TextWindow; - -public class TfSaveStamp extends AbstractStamp implements ActionListener, Runnable { - - private JTextField txt = new JTextField(IJ.getDirectory("imagej") + File.separator + "models" + File.separator); - private JButton bnBrowse = new JButton("Browse"); - private JButton bnSave = new JButton("Save Bundled Model"); - private HTMLPane pane; - - public TfSaveStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - public void buildPanel() { - pane = new HTMLPane(Constants.width, 320); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "Saving Bundled Model"); - JScrollPane infoPane = new JScrollPane(pane); - //infoPane.setPreferredSize(new Dimension(Constants.width, pane.getPreferredSize().height)); - DeepImageJ dp = parent.getDeepPlugin(); - - if (dp != null) - if (dp.params != null) - if (dp.params.path2Model != null) - txt.setText(dp.params.path2Model); - txt.setFont(new Font("Arial", Font.BOLD, 14)); - txt.setForeground(Color.red); - txt.setText(IJ.getDirectory("imagej") + File.separator + "models" + File.separator); - JPanel load = new JPanel(new BorderLayout()); - load.setBorder(BorderFactory.createEtchedBorder()); - load.add(txt, BorderLayout.CENTER); - load.add(bnBrowse, BorderLayout.EAST); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(load, BorderLayout.NORTH); - pn.add(infoPane, BorderLayout.CENTER); - JPanel pnButtons = new JPanel(new GridLayout(2, 1)); - pnButtons.add(bnSave); - pn.add(bnSave, BorderLayout.SOUTH); - panel.add(pn); - - bnSave.addActionListener(this); - txt.setDropTarget(new LocalDropTarget()); - load.setDropTarget(new LocalDropTarget()); - bnBrowse.addActionListener(this); - } - - @Override - public void init() { - } - - @Override - public boolean finish() { - return true; - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnBrowse) { - browse(); - } - if (e.getSource() == bnSave) { - save(); - } - } - - private void browse() { - JFileChooser chooser = new JFileChooser(txt.getText()); - chooser.setFileSelectionMode(JFileChooser.DIRECTORIES_ONLY); - chooser.setDialogTitle("Select model"); - int ret = chooser.showSaveDialog(new JFrame()); - if (ret == JFileChooser.APPROVE_OPTION) { - txt.setText(chooser.getSelectedFile().getAbsolutePath()); - txt.setCaretPosition(1); - } - } - - public void save() { - Thread thread = new Thread(this); - thread.setPriority(Thread.MIN_PRIORITY); - thread.start(); - - } - @Override - public void run() { - DeepImageJ dp = parent.getDeepPlugin(); - Parameters params = dp.params; - params.saveDir = txt.getText() + File.separator; - params.saveDir = params.saveDir.replace(File.separator + File.separator, File.separator); - File dir = new File(params.saveDir); - - if (dir.exists() && dir.isDirectory()) { - pane.append("p", "Path introduced corresponded to an already existing directory."); - pane.append("p", "Model not saved"); - IJ.error("Directory: \n" + dir.getAbsolutePath() + "\n already exists. Please introduce other name."); - return; - } - - if (!dir.exists()) { - dir.mkdir(); - pane.append("p", "Making directory: " + params.saveDir); - } - - dir = new File(params.saveDir); - - - // Save the model architecture - - params.biozoo = true; - try { - String zipName = "tensorflow_saved_model_bundle.zip"; - pane.append("p", "Writting zip file..."); - FileTools.zip(new String[]{params.path2Model + File.separator + "variables", params.path2Model + File.separator + "saved_model.pb"}, params.saveDir + File.separator + zipName); - pane.append("p", "Tensorflow Bioimage Zoo model: saved"); - } - catch (Exception e) { - e.printStackTrace(); - pane.append("p", "Error zipping the varaibles folder and saved_model.pb"); - pane.append("p", "Zipped Tensorflow model: not saved"); - } - - // List with processing files saved - ArrayList saved = new ArrayList(); - // Save preprocessing - if (params.firstPreprocessing != null) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.firstPreprocessing).getName()); - FileTools.copyFile(new File(params.firstPreprocessing), destFile); - pane.append("p", "First preprocessing: saved"); - saved.add(params.firstPreprocessing); - } - catch (Exception e) { - pane.append("p", "First preprocessing: not saved"); - } - } - if (params.secondPreprocessing != null && !saved.contains(params.secondPreprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.secondPreprocessing).getName()); - FileTools.copyFile(new File(params.secondPreprocessing), destFile); - pane.append("p", "Second preprocessing: saved"); - saved.add(params.secondPreprocessing); - } - catch (Exception e) { - pane.append("p", "Second preprocessing: not saved"); - } - } else if (params.secondPreprocessing != null) { - pane.append("p", "Second preprocessing: saved"); - } - - // Save postprocessing - if (params.firstPostprocessing != null && !saved.contains(params.firstPostprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.firstPostprocessing).getName()); - FileTools.copyFile(new File(params.firstPostprocessing), destFile); - pane.append("p", "First postprocessing: saved"); - } - catch (Exception e) { - pane.append("p", "First postprocessing: not saved"); - } - } else if (params.firstPostprocessing != null) { - pane.append("p", "First postprocessing: saved"); - } - if (params.secondPostprocessing != null && !saved.contains(params.secondPostprocessing)) { - try { - File destFile = new File(params.saveDir + File.separator + new File(params.secondPostprocessing).getName()); - FileTools.copyFile(new File(params.secondPostprocessing), destFile); - pane.append("p", "Second postprocessing: saved"); - } - catch (Exception e) { - pane.append("p", "Second postprocessing: not saved"); - } - } else if (params.secondPostprocessing != null) { - pane.append("p", "Second postprocessing: saved"); - } - - // Save input image - try { - if (params.testImageBackup != null) { - // Get name with no extension - String title = getTitleWithoutExtension(params.testImageBackup.getTitle().substring(4)); - IJ.saveAsTiff(params.testImageBackup, params.saveDir + File.separator + title + ".tif"); - pane.append("p", title + ".tif" + ": saved"); - boolean npySaved = saveNpyFile(params.testImageBackup, "XYCZN", params.saveDir + File.separator + title + ".npy"); - if (npySaved) - pane.append("p", title + ".npy" + ": saved"); - else - pane.append("p", title + ".npy: not saved"); - params.testImageBackup.setTitle("DUP_" + params.testImageBackup.getTitle()); - } else { - throw new Exception(); - } - } - catch(Exception ex) { - pane.append("p", "exampleImage.tif: not saved"); - pane.append("p", "exampleImage.npy: not saved"); - } - - // Save output images and tables (tables as saved as csv) - for (HashMap output : params.savedOutputs) { - String name = output.get("name"); - String nameNoExtension= getTitleWithoutExtension(name); - try { - if (output.get("type").contains("image")) { - ImagePlus im = WindowManager.getImage(name); - IJ.saveAsTiff(im, params.saveDir + File.separator + nameNoExtension + ".tif"); - im.setTitle(name); - pane.append("p", nameNoExtension + ".tif" + ": saved"); - boolean npySaved = saveNpyFile(im, "XYCZB", params.saveDir + File.separator + nameNoExtension + ".npy"); - if (npySaved) - pane.append("p", nameNoExtension + ".npy" + ": saved"); - else - pane.append("p", nameNoExtension + ".npy: not saved"); - } else if (output.get("type").contains("ResultsTable")){ - Frame f = WindowManager.getFrame(name); - if (f!=null && (f instanceof TextWindow)) { - ResultsTable rt = ((TextWindow)f).getResultsTable(); - rt.save(params.saveDir + File.separator + nameNoExtension + ".csv"); - pane.append("p", nameNoExtension + ".csv" + ": saved"); - boolean npySaved = saveNpyFile(rt, params.saveDir + File.separator + nameNoExtension + ".npy", "RC"); - if (npySaved) - pane.append("p", nameNoExtension + ".npy" + ": saved"); - else - pane.append("p", nameNoExtension + ".npy: not saved"); - } else { - throw new Exception(); } - } - } - catch(Exception ex) { - pane.append("p", nameNoExtension + ".tif: not saved"); - pane.append("p", nameNoExtension + ".npy: not saved"); - } - } - - // Save yaml - try { - YAMLUtils.writeYaml(dp); - pane.append("p", "rdf.yaml: saved"); - } - catch(IOException ex) { - pane.append("p", "rdf.yaml: not saved"); - IJ.error("Model file was locked or does not exist anymore."); - } - catch(Exception ex) { - ex.printStackTrace(); - pane.append("p", "rdf.yaml: not saved"); - } - - // Finally save external dependencies - boolean saveDeps = saveExternalDependencies(params); - if (saveDeps && params.attachments.size() > 0) { - pane.append("p", "External dependencies: saved"); - } else if (!saveDeps && params.attachments.size() > 0) { - pane.append("p", "External dependencies: not saved"); - } - - pane.append("p", "Done!!"); - - } - - /** - * Save jar dependency files indicated by the developer and saved at params.attachments. - * Saves all types of files but '.jar' o '.class' files - * @param params - * @return true if saving was successful or false otherwise - */ - public static boolean saveExternalDependencies(Parameters params) { - boolean saved = true; - ArrayList savedFiles = new ArrayList(); - String errMsg = "DeepImageJ unable to save:\n"; - for (String dep : params.attachments) { - if (savedFiles.contains(new File(dep).getName()) || (new File(dep).getName()).endsWith(".jar") || (new File(dep).getName()).endsWith(".class")) - continue; - File destFile = new File(params.saveDir + File.separator + new File(dep).getName()); - try { - FileTools.copyFile(new File(dep), destFile); - savedFiles.add(new File(dep).getName()); - } catch (IOException e) { - saved = false; - errMsg += " - " + new File(dep).getName(); - e.printStackTrace(); - } - } - if (!saved) { - IJ.error(errMsg); - } - return saved; - } - - /* - * Gets the image title without the extension - */ - public static String getTitleWithoutExtension(String title) { - int lastDot = title.lastIndexOf("."); - if (lastDot == -1) - return title; - return title.substring(0, lastDot); - } - - public static boolean saveNpyFile(ImagePlus im, String form, String name) { - Path path = Paths.get(name); - String imTitle = name.substring(name.lastIndexOf(File.separator) + 1, name.lastIndexOf(".")); - long[] imShapeLong = ImagePlus2Tensor.getTensorShape(im, form); - int[] imShape = new int[imShapeLong.length]; - // If the number of pixels is too big let the user know that they might prefer not saving - // the results in npy format - String msg = "Do you want to save the image '" + imTitle + "' in .npy format.\n" - + "Saving it might take too long. Do you want to continue?"; - boolean accept = IJ.showMessageWithCancel("Cancel .npy file save", msg); - if (!accept) - return false; - - float[] imArray = ImagePlus2Tensor.implus2IntArray(im, form); - NpyFile.write(path, imArray, imShape); - return true; - } - - public static boolean saveNpyFile(ResultsTable table, String name, String form) { - Path path = Paths.get(name); - int[] shape = Table2Tensor.getTableShape(form, table); - // Convert the array into long - long[] shapeLong = new long[shape.length]; - // If the number of pixels is too big let the user know that they might prefer not saving - // the results in npy format - String msg = "Do you want to save the table '" + table.getTitle() + "' in .npy format.\n" - + "Saving it might take too long. Do you want to continue?"; - boolean accept = IJ.showMessageWithCancel("Cancel .npy file save", msg); - if (!accept) - return false; - // Get the array - float[] flatRt = Table2Tensor.tableToFlatArray(table, form, shapeLong); - NpyFile.write(path, flatRt, shape); - return true; - } - - public class LocalDropTarget extends DropTarget { - - @Override - public void drop(DropTargetDropEvent e) { - e.acceptDrop(DnDConstants.ACTION_COPY); - e.getTransferable().getTransferDataFlavors(); - Transferable transferable = e.getTransferable(); - DataFlavor[] flavors = transferable.getTransferDataFlavors(); - for (DataFlavor flavor : flavors) { - if (flavor.isFlavorJavaFileListType()) { - try { - List files = (List) transferable.getTransferData(flavor); - for (File file : files) { - txt.setText(file.getAbsolutePath()); - txt.setCaretPosition(1); - } - } - catch (UnsupportedFlavorException ex) { - ex.printStackTrace(); - } - catch (IOException ex) { - ex.printStackTrace(); - } - } - } - e.dropComplete(true); - super.drop(e); - } - } - - -} diff --git a/src/main/java/deepimagej/stamp/WelcomeStamp.java b/src/main/java/deepimagej/stamp/WelcomeStamp.java deleted file mode 100755 index 63a06921..00000000 --- a/src/main/java/deepimagej/stamp/WelcomeStamp.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.stamp; - -import java.awt.BorderLayout; -import java.awt.Color; -import java.awt.Dimension; -import java.awt.Font; -import java.awt.datatransfer.DataFlavor; -import java.awt.datatransfer.Transferable; -import java.awt.datatransfer.UnsupportedFlavorException; -import java.awt.dnd.DnDConstants; -import java.awt.dnd.DropTarget; -import java.awt.dnd.DropTargetDropEvent; -import java.awt.event.ActionEvent; -import java.awt.event.ActionListener; -import java.io.File; -import java.io.IOException; -import java.util.List; -import java.util.Vector; - -import javax.swing.BorderFactory; -import javax.swing.JButton; -import javax.swing.JFileChooser; -import javax.swing.JFrame; -import javax.swing.JPanel; -import javax.swing.JTextField; - -import deepimagej.BuildDialog; -import deepimagej.Constants; -import deepimagej.DeepImageJ; -import deepimagej.components.HTMLPane; -import ij.IJ; -import ij.gui.GenericDialog; - -public class WelcomeStamp extends AbstractStamp implements ActionListener { - - private JTextField txt = new JTextField("Drop zone TensorFlow model (protobuf)"); - private JButton bnBrowse = new JButton("Browse"); - - public WelcomeStamp(BuildDialog parent) { - super(parent); - buildPanel(); - } - - public void buildPanel() { - HTMLPane pane = new HTMLPane(Constants.width, 320); - pane.setBorder(BorderFactory.createEtchedBorder()); - pane.append("h2", "Building Bundled Model"); - pane.append("p", - "This wizard allows to create a bundled model for DeepImageJ in 10 steps. " - + "The first step will consist to load the pretrained TensorFlow or Pytorch (see documentation) model. " - + "At the end, the DeepImageJ Bundled Model is saved in a directory." - + "Then, it can be easily used by the plugin 'DeepImageJ Run'"); - - pane.append("p", "Before to start the building, the following material is required:
        "); - pane.append("li", - "

        A pretrained TensorFlow model version 1.15 or lower. " + "This pretrained model has to be stored in a TensorFlow SavedModel file (save_model.pb and variables)

        "); - pane.append("li", - "

        A pretrained Pytorch Torchscipt model version 1.7.0 or lower. " + "This pretrained model has to be stored in a folder. The path to the folder is what needsto be provided.

        "); - pane.append("li", "

        General information of the pretrained model

        "); - pane.append("li", "

        Knowledge of tensor organization and the tiling strategy

        "); - pane.append("li", "

        Macro or java file of preprocessing and postprocessing

        "); - pane.append("li", "

        A test image

        "); - pane.append("
      "); - pane.append("p", "More information: deepimagej.github.io/deepimagej"); - pane.append("p", "Reference: E. Gómez de Mariscal and C. García-López-de-Haro et al. DeepImageJ:" - + " A user-friendly plugin to run\n" + - "deep learning models in ImageJ. Nat Methods 18, 1192–1195 (2021)"); - pane.append("
      "); - pane.append("p", - "© 2019 - 2022. Biomedical Imaging Group, Ecole Polytechnique Fédérale de Lausanne (EPFL), Switzerland " - + "and Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain."); - - txt.setFont(new Font("Arial", Font.BOLD, 14)); - txt.setForeground(Color.red); - txt.setPreferredSize(new Dimension(Constants.width, 25)); - JPanel load = new JPanel(new BorderLayout()); - load.setBorder(BorderFactory.createEtchedBorder()); - load.add(txt, BorderLayout.CENTER); - load.add(bnBrowse, BorderLayout.EAST); - - JPanel pn = new JPanel(new BorderLayout()); - pn.add(pane.getPane(), BorderLayout.CENTER); - pn.add(load, BorderLayout.SOUTH); - panel.add(pn); - - txt.setDropTarget(new LocalDropTarget()); - load.setDropTarget(new LocalDropTarget()); - bnBrowse.addActionListener(this); - } - - @Override - public void init() { - } - - @Override - public boolean finish() { - String filename = txt.getText(); - File file = new File(filename); - if (!file.exists()) { - IJ.error("This directory " + filename + " doesn't exist"); - return false; - } - - // TODO for the moment only allow folder models - if (!file.isDirectory()) { - IJ.error("This file " + filename + " does not correspond to a Pytorch or Tensorflow model."); - return false; - } - - File pb = new File(filename + File.separator + "saved_model.pb"); - if (!pb.exists() && !DeepImageJ.isTherePytorch(file)) { - IJ.error("This directory " + filename + " is not a protobuf model (no saved_model.pb)" - + "\nmodel (no saved_model.pb) neither a Pytorch model"); - return false; - } - return true; - } - - public class LocalDropTarget extends DropTarget { - - @Override - public void drop(DropTargetDropEvent e) { - e.acceptDrop(DnDConstants.ACTION_COPY); - e.getTransferable().getTransferDataFlavors(); - Transferable transferable = e.getTransferable(); - DataFlavor[] flavors = transferable.getTransferDataFlavors(); - for (DataFlavor flavor : flavors) { - if (flavor.isFlavorJavaFileListType()) { - try { - List files = (List) transferable.getTransferData(flavor); - for (File file : files) { - txt.setText(file.getAbsolutePath()); - txt.setCaretPosition(1); - } - } - catch (UnsupportedFlavorException ex) { - ex.printStackTrace(); - } - catch (IOException ex) { - ex.printStackTrace(); - } - } - } - e.dropComplete(true); - super.drop(e); - } - } - - @Override - public void actionPerformed(ActionEvent e) { - if (e.getSource() == bnBrowse) { - browse(); - } - } - - private void browse() { - JFileChooser chooser = new JFileChooser(txt.getText()); - chooser.setFileSelectionMode(JFileChooser.DIRECTORIES_ONLY); - chooser.setDialogTitle("Select model"); - int ret = chooser.showOpenDialog(new JFrame()); - if (ret == JFileChooser.APPROVE_OPTION) { - txt.setText(chooser.getSelectedFile().getAbsolutePath()); - txt.setCaretPosition(1); - } - } - - public String getModelDir() { - File file = new File(txt.getText()); - if (file.exists()) { - return file.getParent(); - } - return null; - } - public String getModelName() { - File file = new File(txt.getText()); - if (file.exists()) { - return file.getName(); - } - return null; - - } -} diff --git a/src/main/java/deepimagej/tools/ModelLoader.java b/src/main/java/deepimagej/tools/ModelLoader.java index 8ac5877b..6a948ed1 100755 --- a/src/main/java/deepimagej/tools/ModelLoader.java +++ b/src/main/java/deepimagej/tools/ModelLoader.java @@ -48,13 +48,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.concurrent.Callable; -import java.util.zip.ZipException; + import deepimagej.DeepImageJ; import deepimagej.DeepLearningModel; import deepimagej.RunnerProgress; -import deepimagej.stamp.LoadPytorchStamp; import ij.IJ; +import io.bioimage.modelrunner.model.Model; public class ModelLoader implements Callable{ private DeepImageJ dp; @@ -62,15 +62,15 @@ public class ModelLoader implements Callable{ private boolean gpu; private boolean cuda; private boolean show; - private boolean isFiji; + private Model model; - public ModelLoader(DeepImageJ dp, RunnerProgress rp, boolean gpu, boolean cuda, boolean show, boolean isFiji) { - this.dp = dp; + public ModelLoader(DeepImageJ dp, Model model, RunnerProgress rp, boolean gpu, boolean cuda, boolean show) { this.rp = rp; this.gpu = gpu; this.cuda = cuda; this.show = show; - this.isFiji = isFiji; + this.model = model; + this.dp = dp; } @Override @@ -114,13 +114,8 @@ public Boolean call() { // while executing the task if (rp != null) rp.allowStopping(false); - boolean ret = false; - if (dp.params.framework.equals("tensorflow")) { - ret = dp.loadTfModel(true); - } else if (dp.params.framework.equals("pytorch")) { - String ptWeightsPath = dp.getPath() + File.separatorChar + dp.ptName; - ret = dp.loadPtModel(ptWeightsPath, isFiji); - } + dp.setModel(model); + boolean ret = dp.loadModel(); if (ret == false && dp.params.framework.equals("tensorflow")) { IJ.error("Error loading " + dp.getName() + "\nTry using another Tensorflow version."); @@ -160,25 +155,6 @@ public Boolean call() { rp.setGPU("gpu"); } } - - if (dp.params.framework.toLowerCase().equals("pytorch")) { - String ptNativeFileName = LoadPytorchStamp.getNativeLbraryFile(); - File libFile = new File(ptNativeFileName); - if (!libFile.exists()) { - rp.setGPU(ptNativeFileName); - dp.params.pytorchVersion = DeepLearningModel.getPytorchVersion(); - } else { - // Get the Pytorch version being used reading the fist part of the lib folder - String parentFolderName = libFile.getParentFile().getName(); - dp.params.pytorchVersion = parentFolderName.substring(0, parentFolderName.indexOf("-")); - if (rp != null && libFile.getName().toLowerCase().contains("cpu")) { - rp.setGPU("cpu"); - } else if (rp != null){ - rp.setGPU("gpu"); - } - } - } - return true; } diff --git a/src/main/java/deepimagej/tools/StartTensorflowService.java b/src/main/java/deepimagej/tools/StartTensorflowService.java deleted file mode 100755 index 1b98bda8..00000000 --- a/src/main/java/deepimagej/tools/StartTensorflowService.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * DeepImageJ - * - * https://deepimagej.github.io/deepimagej/ - * - * Reference: DeepImageJ: A user-friendly environment to run deep learning models in ImageJ - * E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, W. Ouyang, L. Donati, M. Unser, E. Lundberg, A. Munoz-Barrutia, D. Sage. - * Submitted 2021. - * Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain - * Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland - * Science for Life Laboratory, School of Engineering Sciences in Chemistry, Biotechnology and Health, KTH - Royal Institute of Technology, Sweden - * - * Authors: Carlos Garcia-Lopez-de-Haro and Estibaliz Gomez-de-Mariscal - * - */ - -/* - * BSD 2-Clause License - * - * Copyright (c) 2019-2021, DeepImageJ - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -package deepimagej.tools; - -import org.scijava.Context; -import org.scijava.service.SciJavaService; - -import ij.IJ; -import net.imagej.ImageJService; -import net.imagej.tensorflow.TensorFlowService; - -public class StartTensorflowService { - - private static TensorFlowService tfService; - private static boolean newContext = false; - - - /* - * Try to load tf using IMageJ-Tensorflow manager. If it fails - * notify that we are on IJ! and that the tf library will be loaded - * from the jars library using libtensorflow.jar and libtensorflow_jni.jar - */ - public static String loadTfLibrary() { - Context ctx = (Context) IJ.runPlugIn("org.scijava.Context", ""); - if (ctx == null) { - ctx = new Context(ImageJService.class, SciJavaService.class); - newContext = true; - } - tfService = ctx.service(TensorFlowService.class); - if (!tfService.getStatus().isLoaded()) { - tfService.initialize(); - tfService.loadLibrary(); - if (tfService.getStatus().isLoaded()) { - return tfService.getStatus().getInfo(); - } else { - IJ.log(tfService.getStatus().getInfo()); - return ""; - } - } - return tfService.getStatus().getInfo(); - } - - public static TensorFlowService getTfService() { - return tfService; - } - - public static void closeTfService() { - tfService.dispose(); - System.out.println("[DEBUG] Close Tensorflow services"); - } -} diff --git a/src/main/java/deepimagej/tools/SystemUsage.java b/src/main/java/deepimagej/tools/SystemUsage.java index acf93896..a519d907 100755 --- a/src/main/java/deepimagej/tools/SystemUsage.java +++ b/src/main/java/deepimagej/tools/SystemUsage.java @@ -53,6 +53,9 @@ import java.lang.management.OperatingSystemMXBean; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import ij.IJ; @@ -61,6 +64,19 @@ public class SystemUsage { private static String checkFiji = null; private static boolean fiji = false; + /** + * HashMap containing the versions of CUDA compatible with each Pytorch versions + */ + public static final Map> MAP_PYTORCH_CUDA = createCudaCompatiblePytorchMap(); + /** + * HashMap containing the versions of CUDA compatible with each Tensorflow versions + */ + public static final Map> MAP_TF_CUDA = createCudaCompatibleTensorflowMap(); + /** + * HashMap containing the versions of CUDA compatible with each Onnx versions + */ + public static final Map> MAP_ONNX_CUDA = createCudaCompatibleOnnxMap(); + public static String getMemoryMB() { MemoryMXBean mem = ManagementFactory.getMemoryMXBean(); double c = 1.0 / (1024.0 * 1024.0); @@ -631,4 +647,55 @@ public static boolean checkFiji() { return fiji; } } + + private static HashMap> createCudaCompatiblePytorchMap(){ + HashMap> map = new HashMap>(); + map.put("1.7.1", Arrays.asList(new String[]{"CUDA 10.1", "CUDA 10.2", "CUDA 11.0"})); + map.put("1.8.1", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.1"})); + map.put("1.9.0", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.1"})); + map.put("1.9.1", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.1"})); + map.put("1.10.0", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.3"})); + map.put("1.11.0", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.3"})); + map.put("1.12.1", Arrays.asList(new String[]{"CUDA 10.2", "CUDA 11.6"})); + map.put("1.13.0", Arrays.asList(new String[]{"CUDA 11.7"})); + return map; + } + + private static HashMap> createCudaCompatibleTensorflowMap(){ + HashMap> map = new HashMap>(); + map.put("1.12.0", Arrays.asList(new String[]{"CUDA 9.0"})); + map.put("1.13.0", Arrays.asList(new String[]{"CUDA 10.0"})); + map.put("1.13.1", Arrays.asList(new String[]{"CUDA 10.0"})); + map.put("1.14.0", Arrays.asList(new String[]{"CUDA 10.0"})); + map.put("1.15.0", Arrays.asList(new String[]{"CUDA 10.0"})); + map.put("2.0.0", Arrays.asList(new String[]{"CUDA 10.0"})); + map.put("2.1.0", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("2.2.0", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("2.3.0", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("2.3.1", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("2.4.0", Arrays.asList(new String[]{"CUDA 11.0"})); + map.put("2.4.1", Arrays.asList(new String[]{"CUDA 11.0"})); + map.put("2.5.0", Arrays.asList(new String[]{"CUDA 11.2"})); + map.put("2.6.0", Arrays.asList(new String[]{"CUDA 11.2"})); + map.put("2.7.0", Arrays.asList(new String[]{"CUDA 11.2"})); + map.put("2.7.1", Arrays.asList(new String[]{"CUDA 11.2"})); + map.put("2.7.4", Arrays.asList(new String[]{"CUDA 11.2"})); + return map; + } + + private static HashMap> createCudaCompatibleOnnxMap(){ + HashMap> map = new HashMap>(); + map.put("8", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("9", Arrays.asList(new String[]{"CUDA 10.1"})); + map.put("10", Arrays.asList(new String[]{"CUDA 10.2"})); + map.put("11", Arrays.asList(new String[]{"CUDA 10.2"})); + map.put("12", Arrays.asList(new String[]{"CUDA 11.0.3"})); + map.put("13", Arrays.asList(new String[]{"CUDA 11.0.3"})); + map.put("14", Arrays.asList(new String[]{"CUDA 11.4"})); + map.put("15", Arrays.asList(new String[]{"CUDA 11.4"})); + map.put("16", Arrays.asList(new String[]{"CUDA 11.4"})); + map.put("17", Arrays.asList(new String[]{"CUDA 11.4"})); + map.put("18", Arrays.asList(new String[]{"CUDA 11.6"})); + return map; + } } diff --git a/src/main/java/deepimagej/tools/YAMLUtils.java b/src/main/java/deepimagej/tools/YAMLUtils.java index f2e53ceb..d53d9c53 100755 --- a/src/main/java/deepimagej/tools/YAMLUtils.java +++ b/src/main/java/deepimagej/tools/YAMLUtils.java @@ -71,332 +71,10 @@ import deepimagej.DeepImageJ; import deepimagej.Parameters; import deepimagej.DeepLearningModel; -import deepimagej.stamp.TfSaveStamp; import ij.IJ; public class YAMLUtils { - public static void writeYaml(DeepImageJ dp) throws NoSuchAlgorithmException, IOException { - Parameters params = dp.params; - - Map data = new LinkedHashMap<>(); - - List> modelInputMapsList = new ArrayList<>(); - List> inputTestInfoList = new ArrayList<>(); - for (DijTensor inp : params.inputList) { - if (inp.tensorType.contains("image")) { - // Create dictionary for each image input - Map inputTensorMap = new LinkedHashMap<>(); - inputTensorMap.put("name", inp.name); - inputTensorMap.put("axes", inp.form.toLowerCase()); - - inputTensorMap.put("data_type", "float32"); - inputTensorMap.put("data_range", Arrays.toString(inp.dataRange)); - if (params.fixedInput) { - inputTensorMap.put("shape", Arrays.toString(inp.recommended_patch)); - } else if (!params.fixedInput) { - Map shape = new LinkedHashMap<>(); - shape.put("min", Arrays.toString(inp.minimum_size)); - int[] aux = new int[inp.minimum_size.length]; - for(int i = 0; i < aux.length; i ++) {aux[i] += inp.step[i];} - shape.put("step", Arrays.toString(aux)); - inputTensorMap.put("shape", shape); - } - inputTensorMap.put("preprocessing", null); - modelInputMapsList.add(inputTensorMap); - - // Now write the test data info - Map inputTestInfo = new LinkedHashMap<>(); - if (params.testImageBackup != null) - inputTestInfo.put("name", params.testImageBackup.getTitle().substring(4)); - else - inputTestInfo.put("name", null); - inputTestInfo.put("size", inp.inputTestSize); - Map pixelSize = new LinkedHashMap<>(); - pixelSize.put("x", inp.inputPixelSizeX); - pixelSize.put("y", inp.inputPixelSizeY); - pixelSize.put("z", inp.inputPixelSizeZ); - inputTestInfo.put("pixel_size", pixelSize); - inputTestInfoList.add(inputTestInfo); - } - } - - // Test output metadata - List> modelOutputMapsList = new ArrayList<>(); - for (DijTensor out : params.outputList) { - // Create dictionary for each input - Map outputTensorMap = getOutput(out, params.pyramidalNetwork, params.allowPatching); - modelOutputMapsList.add(outputTensorMap); - } - - // Write the info of the outputs after postprocesing - List> outputTestInfoList = new ArrayList<>(); - for (HashMap out : params.savedOutputs) { - - Map outputTestInfo = new LinkedHashMap<>(); - outputTestInfo.put("name", out.get("name")); - outputTestInfo.put("type", out.get("type")); - outputTestInfo.put("size", out.get("size")); - outputTestInfoList.add(outputTestInfo); - } - - // Version of the yaml file - data.put("format_version", params.format_version); - // Name of the model - data.put("name", params.name); - // Short description of the model - data.put("description", params.description); - // Timestamp of when the model was created following ISO 8601 - String thisMoment = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSS").format(Calendar.getInstance().getTime()); - data.put("timestamp", thisMoment); - - // Citation - if (params.cite != null && params.cite.size() == 0) - params.cite = null; - data.put("cite", params.cite); - // List of authors who trained/prepared the actual model which is being saved - data.put("authors", params.author); - // Link to the documentation of the model, which contains info about - // the model such as the images used or architecture - data.put("documentation", params.documentation); - // Path to the image that will be used as the cover picture in the Bioimage model Zoo - ArrayList covers = new ArrayList(); - // TODO generalize for several input images - if (params.testImageBackup != null) - covers.add("./" + TfSaveStamp.getTitleWithoutExtension(params.testImageBackup.getTitle().substring(4)) + ".tif"); - else - covers.add(null); - for (HashMap out : params.savedOutputs) { - if (out.get("type").contains("image")) - covers.add("./" + TfSaveStamp.getTitleWithoutExtension(out.get("name")) + ".tif"); - } - data.put("covers", covers); - // Tags that will be used to look for the model in the Bioimage model Zoo - data.put("tags", params.infoTags); - // Type of license of the model - data.put("license", params.license); - // Programming language in which the model was prepared for the Bioimage model zoo - data.put("language", params.language); - // Deep Learning framework with which the model was obtained - data.put("framework", params.framework); - // Git repo where info to the model can be found - data.put("git_repo", params.git_repo); - - // Create field containing the weights format and info - Map weights = new LinkedHashMap<>(); - // Map for a specific format containing the info for the weigths of a given format - Map format_info = new LinkedHashMap<>(); - // Authors of the model. Authors of the package if this model - // is the original model - // TODO how to put it in the interface - format_info.put("authors", null); - format_info.put("parent", null); - - // TODO allow uploading models to github, zenodo or drive - if (params.framework.equals("pytorch")) { - format_info.put("source", "./weights-torchscript.pt"); - } else if (params.framework.equals("tensorflow")) { - format_info.put("source", "./tensorflow_saved_model_bundle.zip"); - } - // For Tensorflow, if upload to biozoo is selected, calculate checksum - // For Pytorch, always calculate checksum - if (params.framework.equals("pytorch")) { - format_info.put("sha256", FileTools.createSHA256(params.saveDir + File.separator + "weights-torchscript.pt")); - } else if (params.framework.equals("tensorflow") && params.biozoo) { - format_info.put("sha256", FileTools.createSHA256(params.saveDir + File.separator + "tensorflow_saved_model_bundle.zip")); - } else if (params.framework.equals("tensorflow") && !params.biozoo) { - format_info.put("sha256", null); - } - // Add the preprocessing attachments to the weights, as they are part of the model - ArrayList aux = new ArrayList(); - for (String str : params.attachments) { - aux.add(new File(str).getName()); - } - // Add information asking the developer to add plugin requirements to the attachments list - aux.add("Include here any plugin that might be required for pre- or post-processing"); - // Map for the attachments used by the model. They can be - // either files or uris. For the moment DIJ only supports files - Map modelAtachments = new LinkedHashMap<>(); - // List of Java files that need to be included to make the plugin run - modelAtachments.put("files", aux); - format_info.put("attachments", modelAtachments); - - if (params.framework.equals("pytorch")) { - weights.put("torchscript", format_info); - } else { - weights.put("tensorflow_saved_model_bundle", format_info); - } - - // Path to the test inputs - ArrayList inputExamples = new ArrayList(); - ArrayList sampleInputs = new ArrayList(); - // TODO generalize for several input images - if (params.testImageBackup != null) { - String title = params.testImageBackup.getTitle().substring(4); - sampleInputs.add("./" + TfSaveStamp.getTitleWithoutExtension(title) + ".tif"); - inputExamples.add("./" + TfSaveStamp.getTitleWithoutExtension(title) + ".npy"); - } else { - inputExamples.add(null); - sampleInputs.add(null); - } - // Path to the test outputs - ArrayList outputExamples = new ArrayList(); - ArrayList sampleOutputs = new ArrayList(); - for (HashMap out : params.savedOutputs) { - if (out.get("type").contains("image")) - sampleOutputs.add("./" + TfSaveStamp.getTitleWithoutExtension(out.get("name")) + ".tif"); - else if (out.get("type").contains("ResultsTable")) - sampleOutputs.add("./" + TfSaveStamp.getTitleWithoutExtension(out.get("name")) + ".csv"); - outputExamples.add("./" + TfSaveStamp.getTitleWithoutExtension(out.get("name")) + ".npy"); - } - - // Info relevant to DeepImageJ, see: https://github.com/bioimage-io/configuration/issues/23 - Map config = new LinkedHashMap<>(); - Map deepimagej = new LinkedHashMap<>(); - deepimagej.put("pyramidal_model", params.pyramidalNetwork); - deepimagej.put("allow_tiling", params.allowPatching); - - // TF model keys - if (params.framework.contains("tensorflow")) { - Map modelKeys = new LinkedHashMap<>(); - // Model tag - modelKeys.put("tensorflow_model_tag", DeepLearningModel.returnTfTag(params.tag)); - // Model signature definition - modelKeys.put("tensorflow_siganture_def", DeepLearningModel.returnTfSig(params.graph)); - deepimagej.put("model_keys", modelKeys); - } else if (params.framework.contains("pytorch")) { - deepimagej.put("model_keys", null); - } - - // Test metadata - Map testInformation = new LinkedHashMap<>(); - // Test input metadata - testInformation.put("inputs", inputTestInfoList); - - // Test output metadata - testInformation.put("outputs", outputTestInfoList); - - // Output size of the examples used to compose the model - testInformation.put("memory_peak", params.memoryPeak); - // Output size of the examples used to compose the model - testInformation.put("runtime", params.runtime); - // Metadata of the example used to compose the model - deepimagej.put("test_information", testInformation); - - - // Put the example inputs and outputs - data.put("test_inputs", inputExamples); - data.put("test_outputs", outputExamples); - data.put("sample_inputs", sampleInputs); - data.put("sample_outputs", sampleOutputs); - - - // TODO what attachments should go here? - ArrayList attachments = new ArrayList(); - data.put("attachments", attachments); - // Link to the folder containing the weights - data.put("weights", weights); - - data.put("inputs", modelInputMapsList); - data.put("outputs", modelOutputMapsList); - - - // Preprocessing - List> listPreprocess = new ArrayList>(); - if (params.firstPreprocessing == null) { - params.firstPreprocessing = params.secondPostprocessing; - params.secondPreprocessing = null; - } - - int c = 0; - if ((params.firstPreprocessing != null) && (params.firstPreprocessing.contains(".ijm") || params.firstPreprocessing.contains(".txt"))) { - Map preprocess = new LinkedHashMap<>(); - preprocess.put("spec", "ij.IJ::runMacroFile"); - preprocess.put("kwargs", new File(params.firstPreprocessing).getName()); - listPreprocess.add(preprocess); - } else if ((params.firstPreprocessing != null) && (params.firstPreprocessing.contains(".class") || params.firstPreprocessing.contains(".jar"))) { - String filename = new File(params.firstPreprocessing).getName(); - Map preprocess = new LinkedHashMap<>(); - preprocess.put("spec", filename + " " + params.javaPreprocessingClass.get(c ++) + "::preProcessingRoutineUsingImage"); - listPreprocess.add(preprocess); - } else if (params.firstPreprocessing == null && params.secondPreprocessing == null) { - Map preprocess = new LinkedHashMap<>(); - preprocess.put("spec", null); - listPreprocess.add(preprocess); - } - if ((params.secondPreprocessing != null) && (params.secondPreprocessing.contains(".ijm") || params.secondPreprocessing.contains(".txt"))) { - Map preprocess = new LinkedHashMap<>(); - preprocess.put("spec", "ij.IJ::runMacroFile"); - preprocess.put("kwargs", new File(params.secondPreprocessing).getName()); - listPreprocess.add(preprocess); - } else if ((params.secondPreprocessing != null) && (params.secondPreprocessing.contains(".class") || params.secondPreprocessing.contains(".jar"))) { - String filename = new File(params.secondPreprocessing).getName(); - Map preprocess = new LinkedHashMap<>(); - preprocess.put("spec", filename + " " + params.javaPreprocessingClass.get(c ++) + "::preProcessingRoutineUsingImage"); - listPreprocess.add(preprocess); - } - - // Postprocessing - List> listPostprocess = new ArrayList>(); - if (params.firstPostprocessing == null) { - params.firstPostprocessing = params.secondPostprocessing; - params.secondPostprocessing = null; - } - c = 0; - if ((params.firstPostprocessing != null) && (params.firstPostprocessing.contains(".ijm") || params.firstPostprocessing.contains(".txt"))) { - Map postprocess = new LinkedHashMap<>(); - postprocess.put("spec", "ij.IJ::runMacroFile"); - postprocess.put("kwargs", new File(params.firstPostprocessing).getName()); - listPostprocess.add(postprocess); - } else if ((params.firstPostprocessing != null) && (params.firstPostprocessing.contains(".class") || params.firstPostprocessing.contains(".jar"))) { - String filename = new File(params.firstPostprocessing).getName(); - Map postprocess = new LinkedHashMap<>(); - postprocess.put("spec", filename + " " + params.javaPostprocessingClass.get(c ++) + "::postProcessingRoutineUsingImage"); - listPostprocess.add(postprocess); - } else if (params.firstPostprocessing == null && params.secondPostprocessing == null) { - Map postprocess = new LinkedHashMap<>(); - postprocess.put("spec", null); - listPostprocess.add(postprocess); - } - if ((params.secondPostprocessing != null) && (params.secondPostprocessing.contains(".ijm") || params.secondPostprocessing.contains(".txt"))) { - Map postprocess = new LinkedHashMap<>(); - postprocess.put("spec", "ij.IJ::runMacroFile"); - postprocess.put("kwargs", new File(params.secondPostprocessing).getName()); - listPostprocess.add(postprocess); - } else if ((params.secondPostprocessing != null) && (params.secondPostprocessing.contains(".class") || params.secondPostprocessing.contains(".jar"))) { - String filename = new File(params.secondPostprocessing).getName(); - Map postprocess = new LinkedHashMap<>(); - postprocess.put("spec", filename + " " + params.javaPostprocessingClass.get(c ++) + "::postProcessingRoutineUsingImage"); - listPostprocess.add(postprocess); - } - - // Prediction, preprocessing and postprocessing together - Map prediction = new LinkedHashMap<>(); - prediction.put("preprocess", listPreprocess); - prediction.put("postprocess", listPostprocess); - - // Information relevant to deepimagej - deepimagej.put("prediction", prediction); - config.put("deepimagej", deepimagej); - data.put("config", config); - - DumperOptions options = new DumperOptions(); - options.setDefaultFlowStyle(DumperOptions.FlowStyle.BLOCK); - options.setDefaultScalarStyle(DumperOptions.ScalarStyle.PLAIN); - options.setIndent(4); - //options.setPrettyFlow(true); - Yaml yaml = new Yaml(options); - FileWriter writer = null; - try { - writer = new FileWriter(new File(params.saveDir, "rdf.yaml")); - yaml.dump(data, writer); - writer.close(); - removeQuotes(new File(params.saveDir, "rdf.yaml")); - } catch (IOException e) { - e.printStackTrace(); - } - } - public static Map readConfig(String yamlFile) { File initialFile = new File(yamlFile); InputStream targetStream = null; diff --git a/src/main/java/deepimagej/tools/weights/KerasWeights.java b/src/main/java/deepimagej/tools/weights/KerasWeights.java new file mode 100644 index 00000000..6335da64 --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/KerasWeights.java @@ -0,0 +1,250 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Class that contains the information for Keras weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class KerasWeights implements WeightFormatInterface { + + /** + * Crate an object that specifies Keras weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information referring to the Keras weights + */ + public KerasWeights(Map weights) { + weightsFormat = "keras_hdf5"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "tensorflow_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/ModelWeight.java b/src/main/java/deepimagej/tools/weights/ModelWeight.java new file mode 100644 index 00000000..a8b0ebb0 --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/ModelWeight.java @@ -0,0 +1,519 @@ +package deepimagej.tools.weights; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * The model weights information for the current model. + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public class ModelWeight +{ + /** + * String representing the selected weight by the user in the DeepIcy GUI + */ + private String selectedEngine; + /** + * String representing the selected weight version by the user in the DeepIcy GUI + */ + private String selectedVersion; + /** + * Object containing the information about the weights selected + */ + private WeightFormatInterface selectedWeights; + /** + * Object containing the information about the weights loaded + */ + private static Map loadedWeights = new HashMap(); + /** + * Map with all the engines defined in the rdf.yaml + */ + private HashMap weightsDic; + private static String kerasIdentifier = "keras_hdf5"; + private static String onnxIdentifier = "onnx"; + private static String torchIdentifier = "pytorch_state_dict"; + private static String tfIdentifier = "tensorflow_saved_model_bundle"; + private static String tfJsIdentifier = "tensorflow_js"; + private static String torchscriptIdentifier = "torchscript"; + private static String bioengineIdentifier = "bioengine"; + private static String gpuSuffix = " (supports gpu)"; + /** + * The key for the weights that are going to be used in the BioEngine + */ + private String bioEngineWeightsKey; + /** + * List of all the not supported Deep Learning frameworks by DeepIcy + */ + private static ArrayList supported = + new ArrayList(Arrays.asList(torchscriptIdentifier, tfIdentifier, onnxIdentifier)); + /** + * Suffix added to the engine version when the engine version is not installed + */ + private static String missingVersion = " (please install)"; + /** + * Suffix added to the engine version when the engine version is not installed + */ + private static String notSupported = " (not supported)"; + /** + * Suffix added to the engine version when another version of + * the same engine has already been loaded + */ + private static String alreadyLoaded = " (restart Icy)"; + /** + * Builds a weight information element from the element map. + * + * @param yamlFieldElements + * The element map. + * @return The model weight information instance. + */ + public static ModelWeight build(Map yamlFieldElements) + { + ModelWeight model = new ModelWeight(); + Set weightsFormats = yamlFieldElements.keySet(); + // Reset the list with the inlcuded frameworks + model.weightsDic = new HashMap(); + for (String ww : weightsFormats) { + Map weights = (Map) yamlFieldElements.get(ww); + if (ww.contentEquals(kerasIdentifier)) { + KerasWeights weightsObject = new KerasWeights(weights); + model.weightsDic.put(model.kerasEngineName(weightsObject), weightsObject); + } else if (ww.contentEquals(onnxIdentifier)) { + OnnxWeights weightsObject = new OnnxWeights(weights); + model.weightsDic.put(model.onnxEngineName(weightsObject), weightsObject); + } else if (ww.contentEquals(torchIdentifier)) { + PytorchWeights weightsObject = new PytorchWeights(weights); + model.weightsDic.put(model.torchEngineName(weightsObject), weightsObject); + } else if (ww.contentEquals(tfIdentifier)) { + TfWeights weightsObject = new TfWeights(weights); + model.weightsDic.put(model.tfEngineName(weightsObject), weightsObject); + } else if (ww.contentEquals(tfJsIdentifier)) { + TfJsWeights weightsObject = new TfJsWeights(weights); + model.weightsDic.put(model.tfJsEngineName(weightsObject), weightsObject); + } else if (ww.contentEquals(torchscriptIdentifier) + || ww.contentEquals("pytorch_script")) { + TorchscriptWeights weightsObject = new TorchscriptWeights(weights); + model.weightsDic.put(model.torchscriptEngineName(weightsObject), weightsObject); + } + } + return model; + } + + /** + * Identifies the weights that are compatible with the Bioengine. The BioEngine + * canot run Tf 1 weights + */ + public void findBioEngineWeights() { + for (Entry entry : weightsDic.entrySet()) { + if (entry.getValue().getWeightsFormat().equals(kerasIdentifier)) { + bioEngineWeightsKey = kerasIdentifier; + return; + } else if (entry.getValue().getWeightsFormat().equals(onnxIdentifier)) { + bioEngineWeightsKey = onnxIdentifier; + return; + } else if (entry.getValue().getWeightsFormat().equals(torchscriptIdentifier)) { + bioEngineWeightsKey = torchscriptIdentifier; + return; + } + } + } + + /** + * Return the key for the Bioengine weights that are going to be used + * @return the key for the bioengine weights + */ + public String getBioEngineWeightsKey() { + return this.bioEngineWeightsKey; + } + + /** + * Return the corresponding weight format + * @return the corresponding weight format. + * @throws Exception if the set of wanted weights is not present + */ + public WeightFormatInterface getWeightsByIdentifier(String weightsFormat) throws IOException + { + WeightFormatInterface ww = weightsDic.get(weightsFormat); + + if (ww == null) { + throw new IOException("The selected model does not contain " + + "a set of " + weightsFormat + " weights."); + } + return ww; + } + + /** + * Return a list containing all the frameworks (engines) where the model has weights + * @return list of supported Deep Learning frameworks with the corresponding version + */ + public List getEnginesListWithVersions(){ + return this.weightsDic.keySet().stream().collect(Collectors.toList()); + } + + /** + * Get list with the supported Deep Learning frameworks. Does not the same framework + * several times if it is repeated. + * @return + */ + public List getSupportedDLFrameworks() { + return weightsDic.entrySet().stream(). + map(i -> i.getValue().getWeightsFormat()). + distinct().collect(Collectors.toList()); + } + + /** + * Get the weights format selected to make inference. + * For models that contain several sets of weights + * from different frameworks in the + * same model folder + * + * @return the selected weights engine + */ + public String getSelectedWeightsIdentifier() { + return selectedEngine; + } + + /** + * GEt the training version of the selected weights + * @return the training version of the selected weights + * @throws IOException if the weights do not exist + */ + public String getWeightsSelectedVersion() throws IOException { + return selectedVersion; + } + + /** + * Return the object containing the information about the selected weights + * @return the yaml information about the selected weights + */ + public WeightFormatInterface getSelectedWeights() { + return this.selectedWeights; + } + + /** + * Sets the Deep Learning framework of the weights of the + * model selected. + * For models that contain several sets of weights + * from different frameworks in the + * + * @param selectedWeights the format (framework) of the weights + * @throws IOException if the weights are not found in the avaiable ones + */ + public void setSelectedWeightsFormat(String selectedWeights) throws IOException { + if (selectedWeights.startsWith(kerasIdentifier)) { + this.selectedEngine = kerasIdentifier; + } else if (selectedWeights.startsWith(onnxIdentifier)) { + this.selectedEngine = onnxIdentifier; + } else if (selectedWeights.startsWith(torchIdentifier)) { + this.selectedEngine = torchIdentifier; + } else if (selectedWeights.startsWith(tfIdentifier)) { + this.selectedEngine = tfIdentifier; + } else if (selectedWeights.startsWith(tfJsIdentifier)) { + this.selectedEngine = tfJsIdentifier; + } else if (selectedWeights.startsWith(torchscriptIdentifier)) { + this.selectedEngine = torchscriptIdentifier; + } else if (selectedWeights.startsWith(bioengineIdentifier)) { + this.selectedEngine = bioengineIdentifier; + } else { + throw new IllegalArgumentException("Unsupported Deep Learning framework in DeepIcy."); + } + setSelectedVersion(selectedWeights); + setSelectedWeights(selectedWeights); + } + + /** + * Sets the Deep Learning engine version selected by the user + * @param selectedWeights + * the selected weights format and version by the user in the GUI + */ + private void setSelectedVersion(String selectedWeights) { + if (selectedWeights.equals(bioengineIdentifier)) { + this.selectedVersion = ""; + return; + } + String preffix = this.selectedEngine + "_v"; + this.selectedVersion = selectedWeights.substring(preffix.length()); + + } + + /** + * Set the pair of weights selected by the user by saving the object that contains the info + * about them + * @param selectedWeights + * the string selected by the user as weights + * @throws IOException if the weights are not found in the avaiable ones + */ + private void setSelectedWeights(String selectedWeights) throws IOException { + this.selectedWeights = getWeightsByIdentifier(selectedWeights); + } + + /** + * Set the weights as loaded. Once there are loaded weights, no other weights of + * that same engine can be loaded + */ + public void setWeightsAsLoaded() { + loadedWeights.put(selectedWeights.getWeightsFormat(), selectedWeights); + } + + /** TODO finish when the BioEngine is better defined in the BioImage.io + * Create the name for the BioEngine weights. The name contains the name of the BioEngine + * + * @return the complete weights name + */ + private String bioEngineName(String server) { + if (server.startsWith("https://")) + server = server.substring("https://".length()); + else if (server.startsWith("http://")) + server = server.substring("http://".length()); + + String name = bioengineIdentifier + " (" + server + ")"; + return name; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String torchscriptEngineName(TorchscriptWeights ww) { + String name = torchscriptIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!this.weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String tfJsEngineName(TfJsWeights ww) { + String name = tfJsIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!this.weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String onnxEngineName(OnnxWeights ww) { + String name = onnxIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!this.weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String tfEngineName(TfWeights ww) { + String name = tfIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String torchEngineName(PytorchWeights ww) { + String name = torchIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * Create the name of a pair of torchscript names. The name contains the name of the weights and + * version number. If no version is provided, "Unknown" is used as version identifier + * @param ww + * weights object + * @return the complete weights name + */ + private String kerasEngineName(KerasWeights ww) { + String name = kerasIdentifier + "_v"; + String suffix = ww.getTrainingVersion(); + if (suffix == null) { + boolean exist = true; + suffix = "Unknown"; + int c = 0; + while (exist) { + if (!weightsDic.keySet().contains(name + suffix + c)) { + suffix = suffix + c; + exist = false; + } + c ++; + } + } + return name + suffix; + } + + /** + * REturn the tag used to identify Deep Learning engines that are not present + * in the local engines repo + * @return + */ + public static String getMissingEngineTag() { + return missingVersion; + } + + /** + * REturn the tag used to identify Deep Learning engines that are not supported by DeepIcy + * @return + */ + public static String getNotSupportedEngineTag() { + return notSupported; + } + + /** + * REturn the tag used to identify Deep Learning engines where another + * version oof the engine has been loaded + * @return + */ + public static String getAlreadyLoadedEngineTag() { + return alreadyLoaded; + } + + /** + * REturn the tag used to identify Deep Learning engines that support GPU + * @return + */ + public static String getGPUSuffix() { + return gpuSuffix; + } + + /** + * + * @return the identifier key used for the Keras Deep Learning framework + */ + public static String getKerasID() { + return kerasIdentifier; + } + + /** + * + * @return the identifier key used for the Onnx Deep Learning framework + */ + public static String getOnnxID() { + return onnxIdentifier; + } + + /** + * + * @return the identifier key used for the Pytorch Deep Learning framework + */ + public static String getPytorchID() { + return torchIdentifier; + } + + /** + * + * @return the identifier key used for the Tensorflow JS Deep Learning framework + */ + public static String getTensorflowJsID() { + return tfJsIdentifier; + } + + /** + * + * @return the identifier key used for the Tensorflow Deep Learning framework + */ + public static String getTensorflowID() { + return tfIdentifier; + } + + /** + * + * @return the identifier key used for the torchscript Deep Learning framework + */ + public static String getTorchscriptID() { + return torchscriptIdentifier; + } + + /** + * + * @return the identifier key used for the Bioengine Deep Learning framework + */ + public static String getBioengineID() { + return bioengineIdentifier; + } +} diff --git a/src/main/java/deepimagej/tools/weights/OnnxWeights.java b/src/main/java/deepimagej/tools/weights/OnnxWeights.java new file mode 100644 index 00000000..6c93f42e --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/OnnxWeights.java @@ -0,0 +1,254 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Class that contains the information for ONNX weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class OnnxWeights implements WeightFormatInterface { + + /** + * Crate an object that specifies ONNX weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information referring to the ONNX weights + */ + public OnnxWeights(Map weights) { + weightsFormat = "onnx"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "opset_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + // TODO add fixed version if it is not shown because many models are missing + // the Pytorch version. Remove when they start appearing + if (trainingVersion == null) + trainingVersion = "17"; + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/PytorchWeights.java b/src/main/java/deepimagej/tools/weights/PytorchWeights.java new file mode 100644 index 00000000..f61e36af --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/PytorchWeights.java @@ -0,0 +1,250 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Class that contains the information for Pytorch state dic weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class PytorchWeights implements WeightFormatInterface{ + + /** + * Crate an object that specifies Pytorch state dic weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information refering to the Pytorch weights + */ + public PytorchWeights(Map weights) { + weightsFormat = "pytorch_state_dict"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "pytorch_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/TfJsWeights.java b/src/main/java/deepimagej/tools/weights/TfJsWeights.java new file mode 100644 index 00000000..fcabd511 --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/TfJsWeights.java @@ -0,0 +1,250 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Class that contains the information for Tensorflow Javascript weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class TfJsWeights implements WeightFormatInterface{ + + /** + * Crate an object that specifies Tensorflow Javascript weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information referring to the Keras weights + */ + public TfJsWeights(Map weights) { + weightsFormat = "tensorflow_js"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "tensorflow_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/TfWeights.java b/src/main/java/deepimagej/tools/weights/TfWeights.java new file mode 100644 index 00000000..9b80241e --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/TfWeights.java @@ -0,0 +1,260 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import io.bioimage.modelrunner.engine.EngineInfo; +import io.bioimage.modelrunner.engine.installation.EngineManagement; + +/** + * Class that contains the information for Tensorflow weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class TfWeights implements WeightFormatInterface{ + + /** + * Crate an object that specifies Tensorflow weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information referring to the Tensorflow weights + */ + public TfWeights(Map weights) { + weightsFormat = "tensorflow_saved_model_bundle"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "tensorflow_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + + // Pin the tensorflow versions (TF 1 or TF 2, according to the engines version downloaded) + if (trainingVersion == null && trainingVersion.startsWith("1")){ + trainingVersion = EngineManagement.ENGINES_VERSIONS.get(EngineInfo.getTensorflowKey() + "_1"); + } + else if (trainingVersion == null) { + trainingVersion = EngineManagement.ENGINES_VERSIONS.get(EngineInfo.getTensorflowKey() + "_2"); + } + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/TorchscriptWeights.java b/src/main/java/deepimagej/tools/weights/TorchscriptWeights.java new file mode 100644 index 00000000..63532ba0 --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/TorchscriptWeights.java @@ -0,0 +1,254 @@ +package deepimagej.tools.weights; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Class that contains the information for Torchscript weights. + * For more information about the parameters go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + * + */ +public class TorchscriptWeights implements WeightFormatInterface{ + + /** + * Crate an object that specifies Torchscript weights + * + * @param weights + * part of the yaml file that contains exclusively the + * information refering to the Torchscript weights + */ + public TorchscriptWeights(Map weights) { + weightsFormat = "torchscript"; + Set keys = weights.keySet(); + for (String k : keys) { + Object fieldElement = weights.get(k); + switch (k) + { + case "pytorch_version": + setTrainingVersion(fieldElement); + break; + case "source": + setSource(fieldElement); + break; + case "attachments": + setAttachments(fieldElement); + break; + case "authors": + setAuthors(fieldElement); + break; + case "parent": + setParent(fieldElement); + break; + case "sha256": + setSha256(fieldElement); + break; + case "architecture": + setArchitecture(fieldElement); + break; + case "architecture_sha256": + setArchitectureSha256(fieldElement); + break; + } + } + // TODO add fixed version if it is not shown because many models are missing + // the Pytorch version. Remove when they start appearing + if (trainingVersion == null) + trainingVersion = "1.13.1"; + } + + private String weightsFormat; + @Override + public String getWeightsFormat() { + return weightsFormat; + } + + private String trainingVersion; + @Override + public String getTrainingVersion() { + return trainingVersion; + } + + /** + * Set the training version for the weights + * specified in the yaml if it exists + * + * @param v + * training version of the weights + */ + public void setTrainingVersion(Object v) { + if (v instanceof String && !((String)v).contains("+") + && !((String)v).contains("cu") + && !((String)v).contains("cuda")) + this.trainingVersion = (String) v; + else if (v instanceof String && ((String)v).contains("+")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("+")).trim(); + else if (v instanceof String && ((String)v).contains("cuda")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cuda")).trim(); + else if (v instanceof String && ((String)v).contains("cu")) + this.trainingVersion = ((String) v).substring(0, ((String) v).indexOf("cu")).trim(); + else if (v instanceof Double) + this.trainingVersion = "" + v; + else if (v instanceof Float) + this.trainingVersion = "" + v; + else if (v instanceof Long) + this.trainingVersion = "" + v; + else if (v instanceof Integer) + this.trainingVersion = "" + v; + } + + private String sha256; + @Override + public String getSha256() { + return sha256; + } + + /** + * Set the SHA256 of the model from the parameters in the yaml + * + * @param s + * SHA256 of the model + */ + public void setSha256(Object s) { + if (s instanceof String) + sha256 = (String) s; + + } + + private String source; + @Override + public String getSource() { + return source; + } + + /** + * Set the source of the model from the parameters in the yaml + * + * @param s + * string from the yaml file containing the source, return only the + * name of the file inside the folder, not the whole path + */ + public void setSource(Object s) { + if (s instanceof String) + this.source = (String) s; + + } + + private List authors; + @Override + public List getAuthors() { + return authors; + } + + /** + * Set the authors of the model + * @param authors + * authors of the model + */ + public void setAuthors(Object authors) { + if (authors instanceof String) { + List authList = new ArrayList(); + authList.add((String) authors); + this.authors = authList; + } else if (authors instanceof List) { + this.authors = (List) authors; + } + + } + + private Map attachments; + @Override + public Map getAttachments() { + return attachments; + } + + /** + * Set the attachments of the weights if they exist + * @param attachments + * attachments of the model + */ + public void setAttachments(Object attachments) { + if (attachments instanceof Map) + this.attachments = (Map) attachments; + + } + + private String parent; + @Override + public String getParent() { + return parent; + } + + /** + * Set the parent of the weights in the case they exist + * @param parent + * parent weights of the model + */ + public void setParent(Object parent) { + if (parent instanceof String) + this.parent = (String) parent; + } + + private String architecture; + @Override + public String getArchitecture() { + return architecture; + } + + /** + * Set the path to the architecture of the weights in the case it exists + * @param architecture + * path to the architecture of the model + */ + public void setArchitecture(Object architecture) { + if (architecture instanceof String) + this.architecture = (String) architecture; + } + + private String architectureSha256; + @Override + public String getArchitectureSha256() { + return architectureSha256; + } + + /** + * Set the architecture Sha256 in the case it exists + * @param architectureSha256 + * architecture Sha256 of the model + */ + public void setArchitectureSha256(Object architectureSha256) { + if (architectureSha256 instanceof String) + this.architectureSha256 = (String) architectureSha256; + } + + @Override + public String getSourceFileName() { + if (source == null) + return source; + return new File(source).getName(); + } + + boolean gpu = false; + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public void supportGPU(boolean support) { + gpu = support; + } + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + @Override + public boolean isSupportGPU() { + return gpu; + } +} diff --git a/src/main/java/deepimagej/tools/weights/WeightFormatInterface.java b/src/main/java/deepimagej/tools/weights/WeightFormatInterface.java new file mode 100644 index 00000000..c09f1b66 --- /dev/null +++ b/src/main/java/deepimagej/tools/weights/WeightFormatInterface.java @@ -0,0 +1,109 @@ +package deepimagej.tools.weights; + +import java.util.List; +import java.util.Map; + +/** + * Interface that contains all the methods needed to create a + * Bioimage.io weight specification for any format. + * For more info go to: + * https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/weight_formats_spec_0_4.md + * + * @author Carlos Garcia Lopez de Haro + */ +public interface WeightFormatInterface { + + /** + * Retrieve the version of the framework used to train the + * weights + * + * @return version used to train the weights + */ + public String getTrainingVersion(); + + /** + * Format of the weights of the model. The supported weights by the Bioimage.io are: + * keras_hdf5, onnx, pytorch_state_dict, tensorflow_js, tensorflow_saved_model_bundle + * and torchscript + * + * @return the Deep Learning framework of the model + */ + public String getWeightsFormat(); + + /** + * SHA256 checksum of the source file + * + * @return SHA256 checksum of the source file + */ + public String getSha256(); + + /** + * REturn URL to the weights in the local machine + * @return URL to the weights in the local machine + */ + public String getSource(); + + /** + * REturn name of the source file for the weights. Does not include the path to the directory + * @return name of the file containing the weights + */ + public String getSourceFileName(); + + /** + * List of the authors that have trained the model in the case there is no + * parent model, or list of authors that converted the weights into this format + * + * @return the authors of the weights + */ + public List getAuthors(); + + /** REturn the attachments needed for the weights + * + * @return the attachements. Attachments consists of a dictionary + * of text keys and list values (that may contain any valid yaml) to + * additional, relevant files that are specific to the current + * weight format. A list of URIs can be listed under the files key + * to included additional files for generating the model package + */ + public Map getAttachments(); + + /** + * Returns the parent of the weights. + * @return the parent of the weights. This is the source weights + * used as input for converting the weights to this format. For + * example, if the weights were converted from the format + * pytorch_state_dict to torchscript, the parent is pytorch_state_dict. + * All weight entries except one (the initial set of weights + * resulting from training the model), need to have this field. + */ + public String getParent(); + + /** + * Source code of the model architecture that either points to a local + * implementation: : or the implementation in an available dependency: + * ... For example: + * my_function.py:MyImplementation or bioimageio.core.some_module.some_class_or_function. + * + * @return the archiecture used to train the weights + */ + public String getArchitecture(); + + /** + * SHA256 of the architecture + * @return the SHA256 of the architecture + */ + public String getArchitectureSha256(); + + /** + * Method to set whether the engine used for this weights supports GPU or not + * @return + */ + public void supportGPU(boolean support); + + /** + * Method to know whether the engine used for this weights supports GPU or not + * @return + */ + public boolean isSupportGPU(); +} diff --git a/src/main/resources/plugins.config b/src/main/resources/plugins.config index 6a9f0d89..41bdd6f7 100644 --- a/src/main/resources/plugins.config +++ b/src/main/resources/plugins.config @@ -4,5 +4,4 @@ Plugins>DeepImageJ, "DeepImageJ Run", DeepImageJ_Run Plugins>DeepImageJ, "DeepImageJ Install Model", DeepImageJ_InstallModel -Plugins>DeepImageJ, "DeepImageJ Validate", DeepImageJ_ImageValidation -Plugins>DeepImageJ, "DeepImageJ Build BundledModel", DeepImageJ_Build_BundledModel \ No newline at end of file +Plugins>DeepImageJ, "DeepImageJ Validate", DeepImageJ_ImageValidation \ No newline at end of file