From 5a5fe336f6f25dfe333b56236ecb13b2f312e914 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 24 Sep 2024 17:33:10 +0200 Subject: [PATCH] correct couple of silly errors --- .../v2/api030/Tensorflow2Interface.java | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/Tensorflow2Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/Tensorflow2Interface.java index a15405b..fc2670c 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/Tensorflow2Interface.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api030/Tensorflow2Interface.java @@ -231,9 +231,17 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio String source = descriptor.getWeights().gettAllSupportedWeightObjects().stream() .filter(ww -> ww.getFramework().equals(EngineInfo.getBioimageioTfKey())) .findFirst().get().getSource(); - source = DownloadModel.getFileNameFromURLString(source); - System.out.println("Unzipping model..."); - ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); + if (new File(source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(source).getAbsolutePath(), modelFolder); + } else if (new File(modelFolder, source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(modelFolder, source).getAbsolutePath(), modelFolder); + } else { + source = DownloadModel.getFileNameFromURLString(source); + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); + } } else { throw new LoadModelException("No model file was found in the model folder"); } @@ -258,7 +266,7 @@ void run(List> inputTensors, List> outputTensors) List inputListNames = new ArrayList(); List inTensors = new ArrayList(); int c = 0; - for (Tensor tt : inputTensors) { + for (Tensor tt : inputTensors) { inputListNames.add(tt.getName()); TType inT = TensorBuilder.build(tt); inTensors.add(inT); @@ -266,7 +274,7 @@ void run(List> inputTensors, List> outputTensors) runner.feed(inputName, inT); } c = 0; - for (Tensor tt : outputTensors) + for (Tensor tt : outputTensors) runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); // Run runner List resultPatchTensors = runner.run();