diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java b/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java index f213964..b0e3279 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java @@ -499,9 +499,62 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt return command; } + public static List getClassLocations(Class... classes) { + List locations = new ArrayList<>(); + for (Class clazz : classes) { + String location = getClassLocation(clazz); + if (location != null && !locations.contains(location)) { + locations.add(location); + } + } + return locations; + } + + private static String getClassLocation(Class clazz) { + try { + String classResource = clazz.getName().replace('.', '/') + ".class"; + URL location = clazz.getClassLoader().getResource(classResource); + System.out.println(location); + + if (location != null) { + String path = extractPath(location); + System.out.println(path); + if (path != null) { + File file = new File(path); + if (file.isFile()) { + // If it's a JAR file + return file.getAbsolutePath(); + } else { + // If it's a directory (for .class files) + return file.getParent(); + } + } + } + } catch (UnsupportedEncodingException e) { + e.printStackTrace(); + } + return null; + } + + private static String extractPath(URL url) throws UnsupportedEncodingException { + String path = url.getPath(); + if (path.startsWith("file:")) { + path = path.substring(5); + } + path = URLDecoder.decode(path, "UTF-8"); + + if (path.contains(".jar!")) { + path = path.substring(0, path.lastIndexOf(".jar!") + 4); + } + + return path; + } + private static String getCurrentClasspath() throws UnsupportedEncodingException { String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); + getClassLocations(DeepLearningEngineInterface.class); + System.out.println("AAAAAAAAAAAAAAAAAAAAAAAAAAAAA"); String imglib2Path = getPathFromClass(NativeType.class); String gsonPath = getPathFromClass(Gson.class); String jnaPath = getPathFromClass(com.sun.jna.Library.class);