Skip to content

Commit

Permalink
make sure the model is closed at the end
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 11, 2024
1 parent a007820 commit bf20909
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 28 deletions.
51 changes: 28 additions & 23 deletions src/main/java/DeepImageJ_Run.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import deepimagej.gui.Gui;
import deepimagej.tools.ImPlusRaiManager;
import ij.IJ;
import ij.ImageJ;
import ij.ImagePlus;
import ij.Macro;
import ij.WindowManager;
Expand All @@ -85,6 +84,8 @@ public class DeepImageJ_Run implements PlugIn {
private String outputFolder;
private String display;

private ModelDescriptor model;

/**
* Keys required to run deepImageJ with a macro
*/
Expand All @@ -94,15 +95,9 @@ public class DeepImageJ_Run implements PlugIn {
*/
final static String[] macroOptionalKeys = new String[] {"inputPath=", "outputFolder=", "displayOutput="};

private static final String MACRO_EXAMPLE = "run(\"DeepImageJ Run\", \"modelPath=LiveCellSegmentationBou "
+ "inputPath=/home/carlos/git/deepimagej-plugin/models/LiveCellSegmentationBou/sample_input_0.tif "
+ "outputFolder=/home/carlos/git/deepimagej-plugin/models/LiveCellSegmentationBou/ displayOutput=all\")";


static public void main(String args[]) {
//new DeepImageJ_Run().run("");
new ImageJ();
IJ.runMacro(MACRO_EXAMPLE);
new DeepImageJ_Run().run("");
}
@Override
public void run(String arg) {
Expand All @@ -113,7 +108,7 @@ public void run(String arg) {
boolean isHeadless = GraphicsEnvironment.isHeadless();
if (!isMacro) {
runGUI();
} else if (isMacro && !isHeadless) {
} else if (isMacro ) { //&& !isHeadless) {
runMacro();
} else if (isHeadless) {
runHeadless();
Expand All @@ -132,8 +127,11 @@ private void runGUI() {
}

/**
* Macro:
* "DeepImageJ Run"
* Macro example:
* run("DeepImageJ Run", "modelPath=/path/to/model/LiveCellSegmentationBou
* inputPath=/path/to/image/sample_input_0.tif
* outputFolder=/path/to/ouput/folder
* displayOutput=null")
*/
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runMacro() {
parseCommand();
Expand All @@ -144,32 +142,39 @@ private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeTy
throw new IllegalArgumentException("The provided output folder does not exist and cannot be created: " + this.inputFolder);

IjAdapter adapter = new IjAdapter();
Runner runner;

try {
ModelDescriptor model = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
if (model.getInputTensors().size() > 1)
throw new IllegalArgumentException("Selected model requires more than one input, currently only models with 1 input"
+ " are supported.");
runner = Runner.create(model);
loadDescriptor();
} catch (ModelSpecsException | IOException e) {
e.printStackTrace();
return;
}
try (Runner runner = Runner.create(model)) {
runner.load();
if (this.inputFolder != null) {
executeOnPath(model, runner, adapter);
executeOnPath(runner, adapter);
} else {
executeOnImagePlus(model, runner, adapter);
executeOnImagePlus(runner, adapter);
}
} catch (ModelSpecsException | IOException | LoadModelException | RunModelException e) {
e.printStackTrace();
return;
}
}

private void loadDescriptor() throws FileNotFoundException, ModelSpecsException, IOException {
model = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
if (model.getInputTensors().size() > 1)
throw new IllegalArgumentException("Selected model requires more than one input, currently only models with 1 input"
+ " are supported.");
}


private void runHeadless() {
// TODO not ready yet
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void executeOnPath(ModelDescriptor model, Runner runner, IjAdapter adapter) throws FileNotFoundException, ModelSpecsException, RunModelException, IOException {
void executeOnPath(Runner runner, IjAdapter adapter) throws FileNotFoundException, ModelSpecsException, RunModelException, IOException {
File ff = new File(this.inputFolder);
if (ff.isDirectory())
this.executeOnFolder(model, runner, adapter);
Expand Down Expand Up @@ -216,7 +221,7 @@ void executeOnFolder(ModelDescriptor model, Runner runner, IjAdapter adapter) th
if (this.outputFolder != null) {
IJ.saveAsTiff(im, this.outputFolder + File.separator + im.getTitle());
}
if (this.display.equals("all")) {
if (display != null && this.display.equals("all")) {
SwingUtilities.invokeLater(() -> im.show());
}
}
Expand All @@ -225,7 +230,7 @@ void executeOnFolder(ModelDescriptor model, Runner runner, IjAdapter adapter) th
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void executeOnImagePlus(ModelDescriptor model, Runner runner, IjAdapter adapter) throws FileNotFoundException, ModelSpecsException, RunModelException, IOException {
void executeOnImagePlus(Runner runner, IjAdapter adapter) throws FileNotFoundException, ModelSpecsException, RunModelException, IOException {
ImagePlus imp = WindowManager.getCurrentImage();
List<Tensor<T>> inputList = adapter.convertToInputTensors(null, model);
List<Tensor<R>> res = runner.run(inputList);
Expand Down
50 changes: 45 additions & 5 deletions src/test/java/TestRunEveryFramework.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.junit.Test;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.download.DownloadModel;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.engine.installation.FileDownloader;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.ZipUtils;

public class TestRunEveryFramework {

private List<String> modelPaths;

private static final File FIJI_DIR = new File("fiji");
private static final File ENGINES_DIR = new File(FIJI_DIR.getAbsolutePath(), "engines");
private static final File MODELS_DIR = new File(FIJI_DIR.getAbsolutePath(), "models");

private static final Map<String, String> MODELS;

Expand All @@ -41,15 +50,19 @@ public class TestRunEveryFramework {
}

@Test
public void testRun() {
public void testRun() throws InterruptedException, IOException {
downloadAndTrackFiji();
installEngines();
installModels();
//runModels();
}

private static void downloadAndTrackFiji() throws MalformedURLException {
private static void downloadAndTrackFiji() throws InterruptedException, IOException {
String url = FIJI_URL.get(PlatformDetection.getOs());
URL website = new URL(url);
DownloadModel.getFileSize(website);
long fijiSize = DownloadModel.getFileSize(website);
Path filePath = Paths.get(website.getPath()).getFileName();
File targetFile = new File(FIJI_DIR.getAbsoluteFile(), filePath.toString());
File targetFile = new File(FIJI_DIR.getAbsolutePath(), filePath.toString());

Thread parentThread = Thread.currentThread();
Thread dnwldthread = new Thread(() -> {
Expand All @@ -61,6 +74,15 @@ private static void downloadAndTrackFiji() throws MalformedURLException {
});
dnwldthread.start();

while (dnwldthread.isAlive()) {
Thread.sleep(300);
System.out.println("Download progress: " + (targetFile.length() / (double) fijiSize) + "%");
}

if (targetFile.length() != fijiSize)
throw new RuntimeException("Size of downloaded Fiji zip is different than the expected.");
ZipUtils.unzipFolder(targetFile.getAbsolutePath(), FIJI_DIR.getAbsolutePath());

}

private static void downloadFiji(URL website, File targetFile, Thread parentThread) throws InterruptedException, IOException {
Expand All @@ -74,4 +96,22 @@ private static void downloadFiji(URL website, File targetFile, Thread parentThre
}
conn.disconnect();
}

private static void installEngines() {
EngineInstall installer = EngineInstall.createInstaller(ENGINES_DIR.getAbsolutePath());
installer.basicEngineInstallation();
}

private void installModels() throws IOException, InterruptedException {
modelPaths = new ArrayList<String>();
List<String> dwnldModels = MODELS.entrySet().stream().map(ee -> ee.getValue()).collect(Collectors.toList());

BioimageioRepo br = BioimageioRepo.connect();
for (String mm : dwnldModels) {
String path = br.downloadModelByID(mm, MODELS_DIR.getAbsolutePath());
modelPaths.add(path);
}

}

}

0 comments on commit bf20909

Please sign in to comment.