Skip to content

Commit

Permalink
correct bug
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 18, 2023
1 parent db672b6 commit 8237d22
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/main/java/deepimagej/DeepImageJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ public void writeParameters(TextArea info) {
}
info.append("Memory peak: " + params.memoryPeak + "\n");
info.append("Runtime: " + params.runtime + "\n");
String modelSize = "-1";
double modelSize = 0;

String ptModelName = "weights-torchscript.pt";
String tfModelName = "tensorflow_saved_model_bundle.zip";
Expand All @@ -345,33 +345,26 @@ public void writeParameters(TextArea info) {
}

if (params.framework.equals("pytorch")) {
modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Weights size: " + modelSize + " MB\n");
modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
info.append("Weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");
} else if (params.framework.equals("tensorflow") && new File(this.getPath(), "variables").exists()) {
modelSize = "" + FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Weights size: " + modelSize + " MB\n");
modelSize = FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0);
info.append("Weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");
} else if (params.framework.equals("tensorflow")) {
modelSize = "" + new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 2);
info.append("Zipped model size: " + modelSize + " MB\n");
modelSize = new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0);
info.append("Zipped model size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");
} else if (params.framework.equals("tensorflow/pytorch") && new File(this.getPath(), "variables").exists()) {
modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Pytorch weights size: " + modelSize + " MB\n");
modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
info.append("Pytorch weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");

modelSize = "" + FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Tensorflow weights size: " + modelSize + " MB\n");
modelSize = FileTools.getFolderSize(this.getPath() + File.separator + "variables") / (1024 * 1024.0);
info.append("Tensorflow weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");
} else if (params.framework.equals("tensorflow/pytorch")) {
modelSize = "" + new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Pytorch weights size: " + modelSize + " MB\n");
modelSize = new File(this.getPath() + File.separator + ptModelName).length() / (1024 * 1024.0);
info.append("Pytorch weights size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");

modelSize = "" + new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0);
modelSize = modelSize.substring(0, modelSize.lastIndexOf(".") + 3);
info.append("Zipped Tensorflow model size: " + modelSize + " MB\n");
modelSize = new File(this.getPath() + File.separator + tfModelName).length() / (1024 * 1024.0);
info.append("Zipped Tensorflow model size: " + Math.round(modelSize * 100.0) / 100.0 + " MB\n");
}

}
Expand All @@ -390,9 +383,20 @@ public static String findNameFromSourceParam(String sourceName, String framework
modelName = "tensorflow_saved_model_bundle.zip";
} else if (modelName.indexOf("/") != -1 && modelName.indexOf("/") < 2) {
modelName = modelName.substring(modelName.indexOf("/") + 1);
} else if (checkValidUrl(modelName)) {
modelName = modelName.substring(modelName.lastIndexOf("/") + 1);
}
return modelName;
}

private static boolean checkValidUrl(String urlString) {
try {
new URL(urlString);
return true;
} catch (MalformedURLException e) {
return false;
}
}

public boolean check(String path) {
File dir = new File(path);
Expand Down

0 comments on commit 8237d22

Please sign in to comment.