diff --git a/src/main/java/deepimagej/DeepImageJ.java b/src/main/java/deepimagej/DeepImageJ.java index 42230ee7..10cb6e26 100755 --- a/src/main/java/deepimagej/DeepImageJ.java +++ b/src/main/java/deepimagej/DeepImageJ.java @@ -329,7 +329,7 @@ public void writeParameters(TextArea info) { } info.append("Memory peak: " + params.memoryPeak + "\n"); info.append("Runtime: " + params.runtime + "\n"); - String modelSize = "-1"; + double modelSize = 0; String ptModelName = "weights-torchscript.pt"; String tfModelName = "tensorflow_saved_model_bundle.zip"; @@ -345,33 +345,26 @@ public void writeParameters(TextArea info) { } if (params.framework.equals("pytorch")) { - modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Weights size: " + modelSize + " MB\n"); + modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); + info.append("Weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); } else if (params.framework.equals("tensorflow") && new File(this.getPath(), "variables").exists()) { - modelSize = "" + FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Weights size: " + modelSize + " MB\n"); + modelSize = FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0); + info.append("Weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); } else if (params.framework.equals("tensorflow")) { - modelSize = "" + new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 2); - info.append("Zipped model size: " + modelSize + " MB\n"); + modelSize = new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0); + info.append("Zipped model size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); } else if (params.framework.equals("tensorflow/pytorch") && new File(this.getPath(), "variables").exists()) { - modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Pytorch weights size: " + modelSize + " MB\n"); + modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); + info.append("Pytorch weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); - modelSize = "" + FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Tensorflow weights size: " + modelSize + " MB\n"); + modelSize = FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0); + info.append("Tensorflow weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); } else if (params.framework.equals("tensorflow/pytorch")) { - modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Pytorch weights size: " + modelSize + " MB\n"); + modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0); + info.append("Pytorch weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); - modelSize = "" + new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0); - modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3); - info.append("Zipped Tensorflow model size: " + modelSize + " MB\n"); + modelSize = new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0); + info.append("Zipped Tensorflow model size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n"); } } @@ -390,9 +383,20 @@ public static String findNameFromSourceParam(String sourceName, String framework modelName = "tensorflow_saved_model_bundle.zip"; } else if (modelName.indexOf("/") != -1 && modelName.indexOf("/") < 2) { modelName = modelName.substring(modelName.indexOf("/") + 1); + } else if (checkValidUrl(modelName)) { + modelName = modelName.substring(modelName.lastIndexOf("/") + 1); } return modelName; } + + private static boolean checkValidUrl(String urlString) { + try { + new URL(urlString); + return true; + } catch (MalformedURLException e) { + return false; + } + } public boolean check(String path) { File dir = new File(path);