diff --git a/pom.xml b/pom.xml
index 6d802f0..597f457 100644
--- a/pom.xml
+++ b/pom.xml
@@ -125,7 +125,7 @@
sign,deploy-to-scijava
- 0.5.8
+ 0.5.10-SNAPSHOT
0.2.0
diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java
new file mode 100644
index 0000000..8c5c3a0
--- /dev/null
+++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java
@@ -0,0 +1,150 @@
+package io.bioimage.modelrunner.tensorflow.v2.api020;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
+
+import io.bioimage.modelrunner.apposed.appose.Types;
+import io.bioimage.modelrunner.apposed.appose.Service.RequestType;
+import io.bioimage.modelrunner.apposed.appose.Service.ResponseType;
+
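+/**
+ * Entry point of the separate JVM process used for interprocessing. It reads JSON-encoded task
+ * requests from stdin (following the Appose-style worker protocol of
+ * {@link io.bioimage.modelrunner.apposed.appose.Service}), dispatches them to a
+ * {@link Tensorflow2Interface} created without interprocessing, and writes JSON responses to stdout.
+ */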
+public class JavaWorker {
+
+ private static LinkedHashMap<String, Object> tasks = new LinkedHashMap<String, Object>();
+
+ private final String uuid;
+
+ private final Tensorflow2Interface ti;
+
+ private boolean cancelRequested = false;
+
+ public static void main(String[] args) {
+
+ try(Scanner scanner = new Scanner(System.in)){
+ Tensorflow2Interface ti;
+ try {
+ ti = new Tensorflow2Interface(false);
+ } catch (IOException | URISyntaxException e) {
+ return;
+ }
+
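+ // Protocol loop: each line read from stdin is a JSON map carrying a "task" uuid and a
+ // "requestType" (EXECUTE or CANCEL); an empty line or the end of the stream stops the worker.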
+ while (true) {
+ String line;
+ try {
+ if (!scanner.hasNextLine()) break;
+ line = scanner.nextLine().trim();
+ } catch (Exception e) {
+ break;
+ }
+
+ if (line.isEmpty()) break;
+ Map<String, Object> request = Types.decode(line);
+ String uuid = (String) request.get("task");
+ String requestType = (String) request.get("requestType");
+
+ if (requestType.equals(RequestType.EXECUTE.toString())) {
+ String script = (String) request.get("script");
+ Map<String, Object> inputs = (Map<String, Object>) request.get("inputs");
+ JavaWorker task = new JavaWorker(uuid, ti);
+ tasks.put(uuid, task);
+ task.start(script, inputs);
+ } else if (requestType.equals(RequestType.CANCEL.toString())) {
+ JavaWorker task = (JavaWorker) tasks.get(uuid);
+ if (task == null) {
+ System.err.println("No such task: " + uuid);
+ continue;
+ }
+ task.cancelRequested = true;
+ } else {
+ break;
+ }
+ }
+ }
+
+ }
+
+ private JavaWorker(String uuid, Tensorflow2Interface ti) {
+ this.uuid = uuid;
+ this.ti = ti;
+ }
+
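+ /**
+ * Runs one of the supported pseudo-scripts ("loadModel", "inference" or "close") on the wrapped
+ * {@link Tensorflow2Interface}, reporting LAUNCH and COMPLETION (or FAILURE) back to the parent process.
+ */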
+ private void executeScript(String script, Map<String, Object> inputs) {
+ Map<String, Object> binding = new LinkedHashMap<String, Object>();
+ binding.put("task", this);
+ if (inputs != null)
+ binding.putAll(inputs);
+
+ this.reportLaunch();
+ try {
+ if (script.equals("loadModel")) {
+ ti.loadModel((String) inputs.get("modelFolder"), null);
+ } else if (script.equals("inference")) {
+ ti.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
+ } else if (script.equals("close")) {
+ ti.closeModel();
+ }
+ } catch(Exception ex) {
+ this.fail(Types.stackTrace(ex));
+ return;
+ }
+ this.reportCompletion();
+ }
+
+ private void start(String script, Map<String, Object> inputs) {
+ new Thread(() -> executeScript(script, inputs), "Appose-" + this.uuid).start();
+ }
+
+ private void reportLaunch() {
+ respond(ResponseType.LAUNCH, null);
+ }
+
+ private void reportCompletion() {
+ respond(ResponseType.COMPLETION, null);
+ }
+
+ private void update(String message, Integer current, Integer maximum) {
+ LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
+
+ if (message != null)
+ args.put("message", message);
+
+ if (current != null)
+ args.put("current", current);
+
+ if (maximum != null)
+ args.put("maximum", maximum);
+ this.respond(ResponseType.UPDATE, args);
+ }
+
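+ /**
+ * Encodes a response map containing the task uuid, the response type and any extra arguments,
+ * and prints it as a single JSON line on stdout for the parent process to parse.
+ */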
+ private void respond(ResponseType responseType, Map<String, Object> args) {
+ Map<String, Object> response = new HashMap<String, Object>();
+ response.put("task", uuid);
+ response.put("responseType", responseType);
+ if (args != null)
+ response.putAll(args);
+ try {
+ System.out.println(Types.encode(response));
+ System.out.flush();
+ } catch(Exception ex) {
+ this.fail(Types.stackTrace(ex));
+ }
+ }
+
+ private void cancel() {
+ this.respond(ResponseType.CANCELATION, null);
+ }
+
+ private void fail(String error) {
+ Map<String, Object> args = null;
+ if (error != null) {
+ args = new HashMap<String, Object>();
+ args.put("error", error);
+ }
+ respond(ResponseType.FAILURE, args);
+ }
+
+}
diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java
index 7de3713..008b989 100644
--- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java
+++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java
@@ -20,9 +20,15 @@
*/
package io.bioimage.modelrunner.tensorflow.v2.api020;
+import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
+import io.bioimage.modelrunner.apposed.appose.Service;
+import io.bioimage.modelrunner.apposed.appose.Types;
+import io.bioimage.modelrunner.apposed.appose.Service.Task;
+import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
+import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.download.DownloadModel;
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.engine.EngineInfo;
@@ -30,14 +36,19 @@
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
+import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer;
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2;
+import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.ZipUtils;
+import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
+import net.imglib2.util.Cast;
+import net.imglib2.util.Util;
import java.io.BufferedReader;
import java.io.File;
@@ -59,8 +70,10 @@
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.SavedModelBundle;
@@ -68,6 +81,7 @@
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.types.family.TType;
/**
* Class to that communicates with the dl-model runner, see
@@ -86,66 +100,23 @@
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
*/
public class Tensorflow2Interface implements DeepLearningEngineInterface {
-
- private static final String[] MODEL_TAGS = { "serve", "inference", "train",
- "eval", "gpu", "tpu" };
-
- private static final String[] TF_MODEL_TAGS = {
- "tf.saved_model.tag_constants.SERVING",
- "tf.saved_model.tag_constants.INFERENCE",
- "tf.saved_model.tag_constants.TRAINING",
- "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU",
- "tf.saved_model.tag_constants.TPU" };
-
- private static final String[] SIGNATURE_CONSTANTS = { "serving_default",
- "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs",
- "tensorflow/serving/predict", "outputs", "inputs",
- "tensorflow/serving/regress", "outputs", "train", "eval",
- "tensorflow/supervised/training", "tensorflow/supervised/eval" };
-
- private static final String[] TF_SIGNATURE_CONSTANTS = {
- "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY",
- "tf.saved_model.signature_constants.CLASSIFY_INPUTS",
- "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME",
- "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES",
- "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES",
- "tf.saved_model.signature_constants.PREDICT_INPUTS",
- "tf.saved_model.signature_constants.PREDICT_METHOD_NAME",
- "tf.saved_model.signature_constants.PREDICT_OUTPUTS",
- "tf.saved_model.signature_constants.REGRESS_INPUTS",
- "tf.saved_model.signature_constants.REGRESS_METHOD_NAME",
- "tf.saved_model.signature_constants.REGRESS_OUTPUTS",
- "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY",
- "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY",
- "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME",
- "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" };
-
- /**
- * Idetifier for the files that contain the data of the inputs
- */
- final private static String INPUT_FILE_TERMINATION = "_model_input";
-
- /**
- * Idetifier for the files that contain the data of the outputs
- */
- final private static String OUTPUT_FILE_TERMINATION = "_model_output";
- /**
- * Key for the inputs in the map that retrieves the file names for interprocess communication
- */
- final private static String INPUTS_MAP_KEY = "inputs";
- /**
- * Key for the outputs in the map that retrieves the file names for interprocess communication
- */
- final private static String OUTPUTS_MAP_KEY = "outputs";
- /**
- * File extension for the temporal files used for interprocessing
- */
- final private static String FILE_EXTENSION = ".dat";
/**
* Name without vesion of the JAR created for this library
*/
private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-";
+ private static final String NAME_KEY = "name";
+ private static final String SHAPE_KEY = "shape";
+ private static final String DTYPE_KEY = "dtype";
+ private static final String IS_INPUT_KEY = "isInput";
+ private static final String MEM_NAME_KEY = "memoryName";
+
+ private List<SharedMemoryArray> shmaInputList = new ArrayList<SharedMemoryArray>();
+
+ private List<SharedMemoryArray> shmaOutputList = new ArrayList<SharedMemoryArray>();
+
+ private List<String> shmaNamesList = new ArrayList<String>();
+
/**
* The loaded Tensorflow 2 model
*/
@@ -158,26 +129,14 @@ public class Tensorflow2Interface implements DeepLearningEngineInterface {
* Whether the execution needs interprocessing (MacOS Interl) or not
*/
private boolean interprocessing = false;
- /**
- * TEmporary dir where to store temporary files
- */
- private String tmpDir;
/**
* Folde containing the model that is being executed
*/
private String modelFolder;
- /**
- * List of temporary files used for interprocessing communication
- */
- private List listTempFiles;
- /**
- * HashMap that maps tensor to the temporal file name for interprocessing
- */
- private HashMap tensorFilenameMap;
/**
* Process where the model is being loaded and executed
*/
- Process process;
+ Service runner;
/**
* TODO the interprocessing is executed for every OS
@@ -185,15 +144,11 @@ public class Tensorflow2Interface implements DeepLearningEngineInterface {
* executed is Windows or Mac or not to know if it is going to need interprocessing
* or not
* @throws IOException if the temporary dir is not found
+ * @throws URISyntaxException
*/
- public Tensorflow2Interface() throws IOException
+ public Tensorflow2Interface() throws IOException, URISyntaxException
{
- boolean isWin = PlatformDetection.isWindows();
- boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64);
- if (true || (isWin && isIntel)) {
- interprocessing = true;
- tmpDir = getTemporaryDir();
- }
+ this(true);
}
/**
@@ -203,19 +158,23 @@ public Tensorflow2Interface() throws IOException
* @param doInterprocessing
* whether to do interprocessing or not
* @throws IOException if the temp dir is not found
+ * @throws URISyntaxException
*/
- private Tensorflow2Interface(boolean doInterprocessing) throws IOException
+ protected Tensorflow2Interface(boolean doInterprocessing) throws IOException, URISyntaxException
{
- if (!doInterprocessing) {
- interprocessing = false;
- } else {
- boolean isWin = PlatformDetection.isMacOS();
- boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64);
- if (isWin && isIntel) {
- interprocessing = true;
- tmpDir = getTemporaryDir();
- }
- }
+ interprocessing = doInterprocessing;
+ if (this.interprocessing) {
+ runner = getRunner();
+ runner.debug((text) -> System.err.println(text));
+ }
+ }
+
+ private Service getRunner() throws IOException, URISyntaxException {
+ List<String> args = getProcessCommandsWithoutArgs();
+ String[] argArr = new String[args.size()];
+ args.toArray(argArr);
+
+ return new Service(new File("."), argArr);
}
/**
@@ -231,8 +190,14 @@ public void loadModel(String modelFolder, String modelSource)
throws LoadModelException
{
this.modelFolder = modelFolder;
- if (interprocessing)
+ if (interprocessing) {
+ try {
+ launchModelLoadOnProcess();
+ } catch (IOException | InterruptedException e) {
+ throw new LoadModelException(Types.stackTrace(e));
+ }
return;
+ }
try {
checkModelUnzipped();
} catch (Exception e) {
@@ -249,6 +214,19 @@ public void loadModel(String modelFolder, String modelSource)
}
}
+ private void launchModelLoadOnProcess() throws IOException, InterruptedException {
+ HashMap<String, Object> args = new HashMap<String, Object>();
+ args.put("modelFolder", modelFolder);
+ Task task = runner.task("loadModel", args);
+ task.waitFor();
+ if (task.status == TaskStatus.CANCELED)
+ throw new RuntimeException("Model loading was canceled");
+ else if (task.status == TaskStatus.FAILED)
+ throw new RuntimeException("Model loading failed in the worker process");
+ else if (task.status == TaskStatus.CRASHED)
+ throw new RuntimeException("The worker process crashed while loading the model");
+ }
+
/**
* Check if an unzipped tensorflow model exists in the model folder,
* and if not look for it and unzip it
@@ -260,7 +238,7 @@ private void checkModelUnzipped() throws LoadModelException, IOException, Except
if (new File(modelFolder, "variables").isDirectory()
&& new File(modelFolder, "saved_model.pb").isFile())
return;
- unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME));
+ unzipTfWeights(ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME));
}
/**
@@ -301,7 +279,8 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio
* and modifies the output list with the results obtained
*/
@Override
- public void run(List> inputTensors, List> outputTensors)
+ public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
throws RunModelException
{
if (interprocessing) {
@@ -311,30 +290,65 @@ public void run(List> inputTensors, List> outputTensors)
Session session = model.session();
Session.Runner runner = session.runner();
List inputListNames = new ArrayList();
- List> inTensors =
- new ArrayList>();
+ List<TType> inTensors = new ArrayList<TType>();
int c = 0;
- for (Tensor tt : inputTensors) {
+ for (Tensor<? extends RealType<?>> tt : inputTensors) {
inputListNames.add(tt.getName());
- org.tensorflow.Tensor> inT = TensorBuilder.build(tt);
+ TType inT = TensorBuilder.build(tt);
inTensors.add(inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}
c = 0;
- for (Tensor tt : outputTensors)
+ for (Tensor<? extends RealType<?>> tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
- List> resultPatchTensors = runner.run();
+ List resultPatchTensors = runner.run();
// Fill the agnostic output tensors list with data from the inference result
fillOutputTensors(resultPatchTensors, outputTensors);
// Close the remaining resources
- session.close();
- for (org.tensorflow.Tensor> tt : inTensors) {
+ for (TType tt : inTensors) {
tt.close();
}
- for (org.tensorflow.Tensor> tt : resultPatchTensors) {
+ for (org.tensorflow.Tensor tt : resultPatchTensors) {
+ tt.close();
+ }
+ }
+
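+ /**
+ * Variant of run used from {@link JavaWorker}: the inputs and outputs are JSON-encoded tensor
+ * specifications (name, shape, dtype, memory name) whose data lives in shared memory, and the
+ * inference results are written back into the shared memory regions named by the output specs.
+ */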
+ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
+ Session session = model.session();
+ Session.Runner runner = session.runner();
+
+ List<TType> inTensors = new ArrayList<TType>();
+ int c = 0;
+ for (String ee : inputs) {
+ Map<String, Object> decoded = Types.decode(ee);
+ SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
+ TType inT = io.bioimage.modelrunner.tensorflow.v2.api020.shm.TensorBuilder.build(shma);
+ if (PlatformDetection.isWindows()) shma.close();
+ inTensors.add(inT);
+ String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
+ runner.feed(inputName, inT);
+ }
+
+ c = 0;
+ for (String ee : outputs)
+ runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++));
+ // Run runner
+ List resultPatchTensors = runner.run();
+
+ // Fill the agnostic output tensors list with data from the inference result
+ c = 0;
+ for (String ee : outputs) {
+ Map<String, Object> decoded = Types.decode(ee);
+ ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
+ }
+ // Close the remaining resources
+ for (TType tt : inTensors) {
+ tt.close();
+ }
+ for (org.tensorflow.Tensor tt : resultPatchTensors) {
tt.close();
}
}
@@ -349,32 +363,110 @@ public void run(List> inputTensors, List> outputTensors)
* expected results of the model
* @throws RunModelException if there is any issue running the model
*/
- public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException {
- createTensorsForInterprocessing(inputTensors);
- createTensorsForInterprocessing(outputTensors);
+ public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
+ shmaInputList = new ArrayList<SharedMemoryArray>();
+ shmaOutputList = new ArrayList<SharedMemoryArray>();
+ List<String> encIns = modifyForWinCmd(encodeInputs(inputTensors));
+ List<String> encOuts = modifyForWinCmd(encodeOutputs(outputTensors));
+ LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
+ args.put("inputs", encIns);
+ args.put("outputs", encOuts);
+
try {
- List args = getProcessCommandsWithoutArgs();
- for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);}
- for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);}
- ProcessBuilder builder = new ProcessBuilder(args);
- builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
- builder.redirectError(ProcessBuilder.Redirect.INHERIT);
- process = builder.start();
- int result = process.waitFor();
- process.destroy();
- if (result != 0)
- throw new RunModelException("Error executing the Tensorflow 2 model in"
- + " a separate process. The process was not terminated correctly."
- + System.lineSeparator() + readProcessStringOutput(process));
- } catch (RunModelException e) {
- closeModel();
- throw e;
+ Task task = runner.task("inference", args);
+ task.waitFor();
+ if (task.status == TaskStatus.CANCELED)
+ throw new RuntimeException("Inference was canceled");
+ else if (task.status == TaskStatus.FAILED)
+ throw new RuntimeException("Inference failed in the worker process");
+ else if (task.status == TaskStatus.CRASHED)
+ throw new RuntimeException("The worker process crashed during inference");
+ for (int i = 0; i < outputTensors.size(); i ++) {
+ String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY);
+ SharedMemoryArray shm = shmaOutputList.stream()
+ .filter(ss -> ss.getName().equals(name)).findFirst().orElse(null);
+ if (shm == null) {
+ shm = SharedMemoryArray.read(name);
+ shmaOutputList.add(shm);
+ }
+ RandomAccessibleInterval<? extends RealType<?>> rai = shm.getSharedRAI();
+ outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
+ }
} catch (Exception e) {
- closeModel();
- throw new RunModelException(e.getCause().toString());
+ closeShmas();
+ if (e instanceof RunModelException)
+ throw (RunModelException) e;
+ throw new RunModelException(Types.stackTrace(e));
}
-
- retrieveInterprocessingTensors(outputTensors);
+ closeShmas();
+ }
+
+ private void closeShmas() {
+ shmaInputList.forEach(shm -> {
+ try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
+ });
+ shmaInputList = null;
+ shmaOutputList.forEach(shm -> {
+ try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
+ });
+ shmaOutputList = null;
+ }
+
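+ // On Windows, wrap every JSON string in quotes and escape the inner quotes before handing it
+ // to the worker process.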
+ private static List<String> modifyForWinCmd(List<String> ins){
+ if (!PlatformDetection.isWindows())
+ return ins;
+ List<String> newIns = new ArrayList<String>();
+ for (String ii : ins)
+ newIns.add("\"" + ii.replace("\"", "\\\"") + "\"");
+ return newIns;
+ }
+
+
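+ /**
+ * Copies each input tensor into a newly created shared memory block and returns one JSON string
+ * per tensor describing its name, shape, data type and the shared memory name to read it from.
+ */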
+ private <T extends RealType<T> & NativeType<T>> List<String> encodeInputs(List<Tensor<T>> inputTensors) {
+ List<String> encodedInputTensors = new ArrayList<String>();
+ Gson gson = new Gson();
+ for (Tensor<T> tt : inputTensors) {
+ SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
+ shmaInputList.add(shma);
+ HashMap<String, Object> map = new HashMap<String, Object>();
+ map.put(NAME_KEY, tt.getName());
+ map.put(SHAPE_KEY, tt.getShape());
+ map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData()));
+ map.put(IS_INPUT_KEY, true);
+ map.put(MEM_NAME_KEY, shma.getName());
+ encodedInputTensors.add(gson.toJson(map));
+ }
+ return encodedInputTensors;
+ }
+
+
+ private <T extends RealType<T> & NativeType<T>>
+ List<String> encodeOutputs(List<Tensor<T>> outputTensors) {
+ Gson gson = new Gson();
+ List<String> encodedOutputTensors = new ArrayList<String>();
+ for (Tensor<? extends RealType<?>> tt : outputTensors) {
+ HashMap<String, Object> map = new HashMap<String, Object>();
+ map.put(NAME_KEY, tt.getName());
+ map.put(IS_INPUT_KEY, false);
+ if (!tt.isEmpty()) {
+ map.put(SHAPE_KEY, tt.getShape());
+ map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData()));
+ SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
+ shmaOutputList.add(shma);
+ map.put(MEM_NAME_KEY, shma.getName());
+ } else if (PlatformDetection.isWindows()){
+ SharedMemoryArray shma = SharedMemoryArray.create(0);
+ shmaOutputList.add(shma);
+ map.put(MEM_NAME_KEY, shma.getName());
+ } else {
+ String memName = SharedMemoryArray.createShmName();
+ map.put(MEM_NAME_KEY, memName);
+ shmaNamesList.add(memName);
+ }
+ encodedOutputTensors.add(gson.toJson(map));
+ }
+ return encodedOutputTensors;
}
/**
@@ -386,9 +478,10 @@ public void runInterprocessing(List> inputTensors, List> out
* @throws RunModelException If the number of tensors expected is not the same
* as the number of Tensors outputed by the model
*/
- public static void fillOutputTensors(
+ public static <T extends RealType<T> & NativeType<T>>
+ void fillOutputTensors(
List> outputTfTensors,
- List> outputTensors) throws RunModelException
+ List<Tensor<T>> outputTensors) throws RunModelException
{
if (outputTfTensors.size() != outputTensors.size())
throw new RunModelException(outputTfTensors.size(), outputTensors.size());
@@ -406,21 +499,31 @@ public static void fillOutputTensors(
*/
@Override
public void closeModel() {
+ if (this.interprocessing && runner != null) {
+ Task task;
+ try {
+ task = runner.task("close");
+ task.waitFor();
+ } catch (IOException | InterruptedException e) {
+ throw new RuntimeException(Types.stackTrace(e));
+ }
+ if (task.status == TaskStatus.CANCELED)
+ throw new RuntimeException("Closing the model was canceled");
+ else if (task.status == TaskStatus.FAILED)
+ throw new RuntimeException("Closing the model failed in the worker process");
+ else if (task.status == TaskStatus.CRASHED)
+ throw new RuntimeException("The worker process crashed while closing the model");
+ this.runner.close();
+ return;
+ } else if (this.interprocessing) {
+ return;
+ }
sig = null;
if (model != null) {
model.session().close();
model.close();
}
model = null;
- if (listTempFiles == null)
- return;
- for (File ff : listTempFiles) {
- if (ff.exists())
- ff.delete();
- }
- listTempFiles = null;
- if (process != null)
- process.destroyForcibly();
}
// TODO make only one
@@ -506,157 +609,8 @@ public static String getModelOutputName(String outputName, int i) {
* @throws RunModelException if there is any error running the model
*/
public static void main(String[] args) throws LoadModelException, IOException, RunModelException {
- Tensorflow2Interface tt = new Tensorflow2Interface(false);
-
- tt.loadModel("/home/carlos/Desktop/Fiji.app/models/model_03bioimageio", null);
- // Unpack the args needed
- if (args.length < 4)
- throw new IllegalArgumentException("Error exectuting Tensorflow 2, "
- + "at least 5 arguments are required:" + System.lineSeparator()
- + " - Folder where the model is located" + System.lineSeparator()
- + " - Temporary dir where the memory mapped files are located" + System.lineSeparator()
- + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator()
- + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator()
- + " - ...." + System.lineSeparator()
- + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator()
- + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator()
- + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator()
- + " - ...." + System.lineSeparator()
- + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator()
- );
- String modelFolder = args[0];
- if (!(new File(modelFolder).isDirectory())) {
- throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' "
- + "should be an existing directory containing a Tensorflow 2 model.");
- }
-
- Tensorflow2Interface tfInterface = new Tensorflow2Interface(false);
- tfInterface.tmpDir = args[1];
- if (!(new File(args[1]).isDirectory())) {
- throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' "
- + "should be an existing directory.");
- }
-
- tfInterface.loadModel(modelFolder, modelFolder);
-
- HashMap> map = tfInterface.getInputTensorsFileNames(args);
- List inputNames = map.get(INPUTS_MAP_KEY);
- List> inputList = inputNames.stream().map(n -> {
- try {
- return tfInterface.retrieveInterprocessingTensorsByName(n);
- } catch (RunModelException e) {
- return null;
- }
- }).collect(Collectors.toList());
- List outputNames = map.get(OUTPUTS_MAP_KEY);
- List> outputList = outputNames.stream().map(n -> {
- try {
- return tfInterface.retrieveInterprocessingTensorsByName(n);
- } catch (RunModelException e) {
- return null;
- }
- }).collect(Collectors.toList());
- tfInterface.run(inputList, outputList);
- tfInterface.createTensorsForInterprocessing(outputList);
- }
-
- /**
- * Get the name of the temporary file associated to the tensor name
- * @param name
- * name of the tensor
- * @return file name associated to the tensor
- */
- private String getFilename4Tensor(String name) {
- if (tensorFilenameMap == null)
- tensorFilenameMap = new HashMap();
- if (tensorFilenameMap.get(name) != null)
- return tensorFilenameMap.get(name);
- LocalDateTime now = LocalDateTime.now();
- DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS");
- String newName = name + "_" + now.format(formatter);
- tensorFilenameMap.put(name, newName);
- return tensorFilenameMap.get(name);
}
- /**
- * Create a temporary file for each of the tensors in the list to communicate with
- * the separate process in MacOS Intel and Windows systems
- * @param tensors
- * list of tensors to be sent
- * @throws RunModelException if there is any error converting the tensors
- */
- private void createTensorsForInterprocessing(List> tensors) throws RunModelException{
- if (this.listTempFiles == null)
- this.listTempFiles = new ArrayList();
- for (Tensor> tensor : tensors) {
- long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor);
- File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION);
- if (!ff.exists()) {
- ff.deleteOnExit();
- this.listTempFiles.add(ff);
- }
- try (RandomAccessFile rd =
- new RandomAccessFile(ff, "rw");
- FileChannel fc = rd.getChannel();) {
- MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile);
- ByteBuffer byteBuffer = mem.duplicate();
- ImgLib2ToMappedBuffer.build(tensor, byteBuffer);
- } catch (IOException e) {
- closeModel();
- throw new RunModelException(e.getCause().toString());
- }
- }
- }
-
- /**
- * Retrieves the data of the tensors contained in the input list from the output
- * generated by the independent process
- * @param tensors
- * list of tensors that are going to be filled
- * @throws RunModelException if there is any issue retrieving the data from the other process
- */
- private void retrieveInterprocessingTensors(List> tensors) throws RunModelException{
- for (Tensor> tensor : tensors) {
- try (RandomAccessFile rd =
- new RandomAccessFile(tmpDir + File.separator
- + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r");
- FileChannel fc = rd.getChannel();) {
- MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
- ByteBuffer byteBuffer = mem.duplicate();
- tensor.setData(MappedBufferToImgLib2.build(byteBuffer));
- } catch (IOException e) {
- closeModel();
- throw new RunModelException(e.getCause().toString());
- }
- }
- }
-
- /**
- * Create a tensor from the data contained in a file named as the parameter
- * provided as an input + the file extension {@link #FILE_EXTENSION}.
- * This file is produced by another process to communicate with the current process
- * @param
- * generic type of the tensor
- * @param name
- * name of the file without the extension ({@link #FILE_EXTENSION}).
- * @return a tensor created with the data in the file
- * @throws RunModelException if there is any problem retrieving the data and cerating the tensor
- */
- private < T extends RealType< T > & NativeType< T > > Tensor
- retrieveInterprocessingTensorsByName(String name) throws RunModelException {
- try (RandomAccessFile rd =
- new RandomAccessFile(tmpDir + File.separator
- + this.getFilename4Tensor(name) + FILE_EXTENSION, "r");
- FileChannel fc = rd.getChannel();) {
- MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
- ByteBuffer byteBuffer = mem.duplicate();
- return MappedBufferToImgLib2.buildTensor(byteBuffer);
- } catch (IOException e) {
- closeModel();
- throw new RunModelException(e.getCause().toString());
- }
- }
-
/**
* if java bin dir contains any special char, surround it by double quotes
* @param javaBin
@@ -684,12 +638,7 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt
String javaHome = System.getProperty("java.home");
String javaBin = javaHome + File.separator + "bin" + File.separator + "java";
- String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
- String imglib2Path = getPathFromClass(NativeType.class);
- if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
- && !modelrunnerPath.contains(File.pathSeparator)))
- modelrunnerPath = System.getProperty("java.class.path");
- String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
+ String classpath = getCurrentClasspath();
ProtectionDomain protectionDomain = Tensorflow2Interface.class.getProtectionDomain();
String codeSource = protectionDomain.getCodeSource().getLocation().getPath();
String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString());
@@ -699,17 +648,34 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt
continue;
classpath += ff.getAbsolutePath() + File.pathSeparator;
}
- String className = Tensorflow2Interface.class.getName();
+ String className = JavaWorker.class.getName();
List command = new LinkedList();
command.add(padSpecialJavaBin(javaBin));
command.add("-cp");
command.add(classpath);
command.add(className);
- command.add(modelFolder);
- command.add(this.tmpDir);
return command;
}
+ private static String getCurrentClasspath() throws UnsupportedEncodingException {
+
+ String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
+ String imglib2Path = getPathFromClass(NativeType.class);
+ String gsonPath = getPathFromClass(Gson.class);
+ String jnaPath = getPathFromClass(com.sun.jna.Library.class);
+ String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class);
+ if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
+ && !modelrunnerPath.contains(File.pathSeparator)))
+ modelrunnerPath = System.getProperty("java.class.path");
+ modelrunnerPath = System.getProperty("java.class.path");
+ String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
+ classpath = classpath + gsonPath + File.pathSeparator;
+ classpath = classpath + jnaPath + File.pathSeparator;
+ classpath = classpath + jnaPlatformPath + File.pathSeparator;
+
+ return classpath;
+ }
+
/**
* Method that gets the path to the JAR from where a specific class is being loaded
* @param clazz
@@ -740,41 +706,6 @@ private static String getPathFromClass(Class> clazz) throws UnsupportedEncodin
return path;
}
- /**
- * Get temporary directory to perform the interprocessing communication in MacOSX
- * intel and Windows
- * @return the tmp dir
- * @throws IOException if the files cannot be written in any of the temp dirs
- */
- private static String getTemporaryDir() throws IOException {
- String tmpDir;
- String enginesDir = getEnginesDir();
- if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) {
- tmpDir = enginesDir + File.separator + "temp";
- if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs()))
- tmpDir = enginesDir;
- } else if (System.getenv("temp") != null
- && Files.isWritable(Paths.get(System.getenv("temp")))) {
- return System.getenv("temp");
- } else if (System.getenv("TEMP") != null
- && Files.isWritable(Paths.get(System.getenv("TEMP")))) {
- return System.getenv("TEMP");
- } else if (System.getenv("tmp") != null
- && Files.isWritable(Paths.get(System.getenv("tmp")))) {
- return System.getenv("tmp");
- } else if (System.getenv("TMP") != null
- && Files.isWritable(Paths.get(System.getenv("TMP")))) {
- return System.getenv("TMP");
- } else if (System.getProperty("java.io.tmpdir") != null
- && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) {
- return System.getProperty("java.io.tmpdir");
- } else {
- throw new IOException("Unable to find temporal directory with writting rights. "
- + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'.");
- }
- return tmpDir;
- }
-
/**
* GEt the directory where the TF2 engine is located if a temporary dir is not found
* @return directory of the engines
@@ -806,64 +737,4 @@ private static String getEnginesDir() {
}
return new File(dir).getParent();
}
-
- /**
- * Retrieve the file names used for interprocess communication
- * @param args
- * args provided to the main method
- * @return a map with a list of input and output names
- */
- private HashMap> getInputTensorsFileNames(String[] args) {
- List inputNames = new ArrayList();
- List outputNames = new ArrayList();
- if (this.tensorFilenameMap == null)
- this.tensorFilenameMap = new HashMap();
- for (int i = 2; i < args.length; i ++) {
- if (args[i].endsWith(INPUT_FILE_TERMINATION)) {
- String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length());
- String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_"));
- inputNames.add(onlyName);
- tensorFilenameMap.put(onlyName, nameWTimestamp);
- } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) {
- String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length());
- String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_"));
- outputNames.add(onlyName);
- tensorFilenameMap.put(onlyName, nameWTimestamp);
-
- }
- }
- if (inputNames.size() == 0)
- throw new IllegalArgumentException("The args to the main method of '"
- + Tensorflow2Interface.class.toString() + "' should contain at "
- + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'.");
- if (outputNames.size() == 0)
- throw new IllegalArgumentException("The args to the main method of '"
- + Tensorflow2Interface.class.toString() + "' should contain at "
- + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'.");
- HashMap> map = new HashMap>();
- map.put(INPUTS_MAP_KEY, inputNames);
- map.put(OUTPUTS_MAP_KEY, outputNames);
- return map;
- }
-
- /**
- * MEthod to obtain the String output of the process in case something goes wrong
- * @param process
- * the process that executed the TF2 model
- * @return the String output that we would have seen on the terminal
- * @throws IOException if the output of the terminal cannot be seen
- */
- private static String readProcessStringOutput(Process process) throws IOException {
- BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
- BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
- String text = "";
- String line;
- while ((line = bufferedErrReader.readLine()) != null) {
- text += line + System.lineSeparator();
- }
- while ((line = bufferedReader.readLine()) != null) {
- text += line + System.lineSeparator();
- }
- return text;
- }
}
diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java
new file mode 100644
index 0000000..64adaa7
--- /dev/null
+++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java
@@ -0,0 +1,869 @@
+/*-
+ * #%L
+ * This project complements the DL-model runner acting as the engine that works loading models
+ * and making inference with Java 0.2.0 API for Tensorflow 2.
+ * %%
+ * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers.
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+package io.bioimage.modelrunner.tensorflow.v2.api020;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+
+import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
+import io.bioimage.modelrunner.bioimageio.download.DownloadModel;
+import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
+import io.bioimage.modelrunner.engine.EngineInfo;
+import io.bioimage.modelrunner.exceptions.LoadModelException;
+import io.bioimage.modelrunner.exceptions.RunModelException;
+import io.bioimage.modelrunner.system.PlatformDetection;
+import io.bioimage.modelrunner.tensor.Tensor;
+import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder;
+import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder;
+import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer;
+import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2;
+import io.bioimage.modelrunner.utils.Constants;
+import io.bioimage.modelrunner.utils.ZipUtils;
+import net.imglib2.type.NativeType;
+import net.imglib2.type.numeric.RealType;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.RandomAccessFile;
+import java.io.UnsupportedEncodingException;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.net.URLDecoder;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.security.ProtectionDomain;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.proto.framework.MetaGraphDef;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+
+/**
+ * Class to that communicates with the dl-model runner, see
+ * @see dlmodelrunner
+ * to execute Tensorflow 2 models. This class is compatible with TF2 Java API 0.2.0.
+ * This class implements the interface {@link DeepLearningEngineInterface} to get the
+ * agnostic {@link io.bioimage.modelrunner.tensor.Tensor}, convert them into
+ * {@link org.tensorflow.Tensor}, execute a Tensorflow 2 Deep Learning model on them and
+ * convert the results back to {@link io.bioimage.modelrunner.tensor.Tensor} to send them
+ * to the main program in an agnostic manner.
+ *
+ * {@link ImgLib2Builder}. Creates ImgLib2 images for the backend
+ * of {@link io.bioimage.modelrunner.tensor.Tensor} from {@link org.tensorflow.Tensor}
+ * {@link TensorBuilder}. Converts {@link io.bioimage.modelrunner.tensor.Tensor} into {@link org.tensorflow.Tensor}
+ *
+ * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
+ */
+public class Tensorflow2Interface2_old implements DeepLearningEngineInterface {
+
+ private static final String[] MODEL_TAGS = { "serve", "inference", "train",
+ "eval", "gpu", "tpu" };
+
+ private static final String[] TF_MODEL_TAGS = {
+ "tf.saved_model.tag_constants.SERVING",
+ "tf.saved_model.tag_constants.INFERENCE",
+ "tf.saved_model.tag_constants.TRAINING",
+ "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU",
+ "tf.saved_model.tag_constants.TPU" };
+
+ private static final String[] SIGNATURE_CONSTANTS = { "serving_default",
+ "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs",
+ "tensorflow/serving/predict", "outputs", "inputs",
+ "tensorflow/serving/regress", "outputs", "train", "eval",
+ "tensorflow/supervised/training", "tensorflow/supervised/eval" };
+
+ private static final String[] TF_SIGNATURE_CONSTANTS = {
+ "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY",
+ "tf.saved_model.signature_constants.CLASSIFY_INPUTS",
+ "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME",
+ "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES",
+ "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES",
+ "tf.saved_model.signature_constants.PREDICT_INPUTS",
+ "tf.saved_model.signature_constants.PREDICT_METHOD_NAME",
+ "tf.saved_model.signature_constants.PREDICT_OUTPUTS",
+ "tf.saved_model.signature_constants.REGRESS_INPUTS",
+ "tf.saved_model.signature_constants.REGRESS_METHOD_NAME",
+ "tf.saved_model.signature_constants.REGRESS_OUTPUTS",
+ "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY",
+ "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY",
+ "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME",
+ "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" };
+
+ /**
+ * Idetifier for the files that contain the data of the inputs
+ */
+ final private static String INPUT_FILE_TERMINATION = "_model_input";
+
+ /**
+ * Idetifier for the files that contain the data of the outputs
+ */
+ final private static String OUTPUT_FILE_TERMINATION = "_model_output";
+ /**
+ * Key for the inputs in the map that retrieves the file names for interprocess communication
+ */
+ final private static String INPUTS_MAP_KEY = "inputs";
+ /**
+ * Key for the outputs in the map that retrieves the file names for interprocess communication
+ */
+ final private static String OUTPUTS_MAP_KEY = "outputs";
+ /**
+ * File extension for the temporal files used for interprocessing
+ */
+ final private static String FILE_EXTENSION = ".dat";
+ /**
+ * Name without vesion of the JAR created for this library
+ */
+ private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-";
+
+ /**
+ * The loaded Tensorflow 2 model
+ */
+ private static SavedModelBundle model;
+ /**
+ * Internal object of the Tensorflow model
+ */
+ private static SignatureDef sig;
+ /**
+ * Whether the execution needs interprocessing (MacOS Interl) or not
+ */
+ private boolean interprocessing = false;
+ /**
+ * TEmporary dir where to store temporary files
+ */
+ private String tmpDir;
+ /**
+ * Folde containing the model that is being executed
+ */
+ private String modelFolder;
+ /**
+ * List of temporary files used for interprocessing communication
+ */
+ private List listTempFiles;
+ /**
+ * HashMap that maps tensor to the temporal file name for interprocessing
+ */
+ private HashMap tensorFilenameMap;
+ /**
+ * Process where the model is being loaded and executed
+ */
+ Process process;
+
+ /**
+ * TODO the interprocessing is executed for every OS
+ * Constructor that detects whether the operating system where it is being
+ * executed is Windows or Mac or not to know if it is going to need interprocessing
+ * or not
+ * @throws IOException if the temporary dir is not found
+ */
+ public Tensorflow2Interface2_old() throws IOException
+ {
+ boolean isWin = PlatformDetection.isWindows();
+ boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64);
+ if (true || (isWin && isIntel)) {
+ interprocessing = true;
+ tmpDir = getTemporaryDir();
+ }
+ }
+
+ /**
+ * Private constructor that can only be launched from the class to create a separate
+ * process to avoid the conflicts that occur in the same process between TF2 and TF1/Pytorch
+ * in Windows and Mac
+ * @param doInterprocessing
+ * whether to do interprocessing or not
+ * @throws IOException if the temp dir is not found
+ */
+ private Tensorflow2Interface2_old(boolean doInterprocessing) throws IOException
+ {
+ if (!doInterprocessing) {
+ interprocessing = false;
+ } else {
+ boolean isWin = PlatformDetection.isMacOS();
+ boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64);
+ if (isWin && isIntel) {
+ interprocessing = true;
+ tmpDir = getTemporaryDir();
+ }
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Load a Tensorflow 2 model. If the machine where the code is
+ * being executed is a MacOS Intel or Windows, the model will be loaded in
+ * a separate process each time the method {@link #run(List, List)}
+ * is called
+ */
+ @Override
+ public void loadModel(String modelFolder, String modelSource)
+ throws LoadModelException
+ {
+ this.modelFolder = modelFolder;
+ if (interprocessing)
+ return;
+ try {
+ checkModelUnzipped();
+ } catch (Exception e) {
+ throw new LoadModelException(e.toString());
+ }
+ model = SavedModelBundle.load(this.modelFolder, "serve");
+ byte[] byteGraph = model.metaGraphDef().toByteArray();
+ try {
+ sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefOrThrow(
+ "serving_default");
+ }
+ catch (InvalidProtocolBufferException e) {
+ System.out.println("Invalid graph");
+ }
+ }
+
+ /**
+ * Check if an unzipped tensorflow model exists in the model folder,
+ * and if not look for it and unzip it
+ * @throws LoadModelException if no model is found
+ * @throws IOException if there is any error unzipping the model
+ * @throws Exception if there is any error related to model packaging
+ */
+ private void checkModelUnzipped() throws LoadModelException, IOException, Exception {
+ if (new File(modelFolder, "variables").isDirectory()
+ && new File(modelFolder, "saved_model.pb").isFile())
+ return;
+ unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME));
+ }
+
+ /**
+ * Method that unzips the tensorflow model zip into the variables
+ * folder and .pb file, if they are saved in a zip
+ * @throws LoadModelException if not zip file is found
+ * @throws IOException if there is any error unzipping
+ */
+ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelException, IOException {
+ if (new File(modelFolder, "tf_weights.zip").isFile()) {
+ System.out.println("Unzipping model...");
+ ZipUtils.unzipFolder(modelFolder + File.separator + "tf_weights.zip", modelFolder);
+ } else if ( descriptor.getWeights().getAllSuportedWeightNames()
+ .contains(EngineInfo.getBioimageioTfKey()) ) {
+ String source = descriptor.getWeights().gettAllSupportedWeightObjects().stream()
+ .filter(ww -> ww.getFramework().equals(EngineInfo.getBioimageioTfKey()))
+ .findFirst().get().getSource();
+ if (new File(source).isFile()) {
+ System.out.println("Unzipping model...");
+ ZipUtils.unzipFolder(new File(source).getAbsolutePath(), modelFolder);
+ } else if (new File(modelFolder, source).isFile()) {
+ System.out.println("Unzipping model...");
+ ZipUtils.unzipFolder(new File(modelFolder, source).getAbsolutePath(), modelFolder);
+ } else {
+ source = DownloadModel.getFileNameFromURLString(source);
+ System.out.println("Unzipping model...");
+ ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder);
+ }
+ } else {
+ throw new LoadModelException("No model file was found in the model folder");
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Run a Tensorflow2 model on the data provided by the {@link Tensor} input list
+ * and modifies the output list with the results obtained
+ */
+ @Override
+ public void run(List> inputTensors, List> outputTensors)
+ throws RunModelException
+ {
+ if (interprocessing) {
+ runInterprocessing(inputTensors, outputTensors);
+ return;
+ }
+ Session session = model.session();
+ Session.Runner runner = session.runner();
+ List inputListNames = new ArrayList();
+ List> inTensors =
+ new ArrayList>();
+ int c = 0;
+ for (Tensor tt : inputTensors) {
+ inputListNames.add(tt.getName());
+ org.tensorflow.Tensor> inT = TensorBuilder.build(tt);
+ inTensors.add(inT);
+ String inputName = getModelInputName(tt.getName(), c ++);
+ runner.feed(inputName, inT);
+ }
+ c = 0;
+ for (Tensor tt : outputTensors)
+ runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
+ // Run runner
+ List> resultPatchTensors = runner.run();
+
+ // Fill the agnostic output tensors list with data from the inference result
+ fillOutputTensors(resultPatchTensors, outputTensors);
+ // Close the remaining resources
+ session.close();
+ for (org.tensorflow.Tensor> tt : inTensors) {
+ tt.close();
+ }
+ for (org.tensorflow.Tensor> tt : resultPatchTensors) {
+ tt.close();
+ }
+ }
+
+ /**
+ * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
+ * to create another process, communicate the model info and tensors to the other
+ * process and then retrieve the results of the other process
+ * @param inputTensors
+ * tensors that are going to be run on the model
+ * @param outputTensors
+ * expected results of the model
+ * @throws RunModelException if there is any issue running the model
+ */
+ public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException {
+ createTensorsForInterprocessing(inputTensors);
+ createTensorsForInterprocessing(outputTensors);
+ try {
+ List args = getProcessCommandsWithoutArgs();
+ for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);}
+ for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);}
+ ProcessBuilder builder = new ProcessBuilder(args);
+ builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
+ builder.redirectError(ProcessBuilder.Redirect.INHERIT);
+ process = builder.start();
+ int result = process.waitFor();
+ process.destroy();
+ if (result != 0)
+ throw new RunModelException("Error executing the Tensorflow 2 model in"
+ + " a separate process. The process was not terminated correctly."
+ + System.lineSeparator() + readProcessStringOutput(process));
+ } catch (RunModelException e) {
+ closeModel();
+ throw e;
+ } catch (Exception e) {
+ closeModel();
+ throw new RunModelException(e.getCause().toString());
+ }
+
+ retrieveInterprocessingTensors(outputTensors);
+ }
+
+ /**
+ * Create the list a list of output tensors agnostic to the Deep Learning
+ * engine that can be readable by the dl-modelrunner
+ *
+ * @param outputTfTensors an List containing dl-modelrunner tensors
+ * @param outputTensors the names given to the tensors by the model
+ * @throws RunModelException If the number of tensors expected is not the same
+ * as the number of Tensors outputed by the model
+ */
+ public static void fillOutputTensors(
+ List> outputTfTensors,
+ List> outputTensors) throws RunModelException
+ {
+ if (outputTfTensors.size() != outputTensors.size())
+ throw new RunModelException(outputTfTensors.size(), outputTensors.size());
+ for (int i = 0; i < outputTfTensors.size(); i++) {
+ outputTensors.get(i).setData(ImgLib2Builder.build(outputTfTensors.get(i)));
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Close the Tensorflow 2 {@link #model} and {@link #sig}. For
+ * MacOS Intel and Windows systems it also deletes the temporary files created to
+ * communicate with the other process
+ */
+ @Override
+ public void closeModel() {
+ sig = null;
+ if (model != null) {
+ model.session().close();
+ model.close();
+ }
+ model = null;
+ if (listTempFiles == null)
+ return;
+ for (File ff : listTempFiles) {
+ if (ff.exists())
+ ff.delete();
+ }
+ listTempFiles = null;
+ if (process != null)
+ process.destroyForcibly();
+ }
+
+ // TODO make only one
+ /**
+ * Retrieves the readable input name from the graph signature definition given
+ * the signature input name.
+ *
+ * @param inputName Signature input name.
+ * @param i position of the input of interest in the list of inputs
+ * @return The readable input name.
+ */
+ public static String getModelInputName(String inputName, int i) {
+ TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null);
+ if (inputInfo == null) {
+ inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i);
+ }
+ if (inputInfo != null) {
+ String modelInputName = inputInfo.getName();
+ if (modelInputName != null) {
+ if (modelInputName.endsWith(":0")) {
+ return modelInputName.substring(0, modelInputName.length() - 2);
+ }
+ else {
+ return modelInputName;
+ }
+ }
+ else {
+ return inputName;
+ }
+ }
+ return inputName;
+ }
+
+ /**
+ * Retrieves the readable output name from the graph signature definition
+ * given the signature output name.
+ *
+ * @param outputName Signature output name.
+ * @param i position of the input of interest in the list of inputs
+ * @return The readable output name.
+ */
+ public static String getModelOutputName(String outputName, int i) {
+ TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null);
+ if (outputInfo == null) {
+ outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i);
+ }
+ if (outputInfo != null) {
+ String modelOutputName = outputInfo.getName();
+ if (modelOutputName.endsWith(":0")) {
+ return modelOutputName.substring(0, modelOutputName.length() - 2);
+ }
+ else {
+ return modelOutputName;
+ }
+ }
+ else {
+ return outputName;
+ }
+ }
+
+
+ /**
+ * Methods to run interprocessing and bypass the errors that occur in MacOS intel
+ * with the compatibility between TF2 and TF1/Pytorch
+ * This method checks that the arguments are correct, retrieves the input and output
+ * tensors, loads the model, makes inference with it and finally sends the tensors
+ * to the original process
+ *
+ * @param args
+ * arguments of the program:
+ * - Path to the model folder
+ * - Path to a temporary dir
+ * - Name of the input 0
+ * - Name of the input 1
+ * - ...
+ * - Name of the output n
+ * - Name of the output 0
+ * - Name of the output 1
+ * - ...
+ * - Name of the output n
+ * @throws LoadModelException if there is any error loading the model
+ * @throws IOException if there is any error reading or writing any file or with the paths
+ * @throws RunModelException if there is any error running the model
+ */
+ public static void main(String[] args) throws LoadModelException, IOException, RunModelException {
+ Tensorflow2Interface2_old tt = new Tensorflow2Interface2_old(false);
+
+ tt.loadModel("/home/carlos/Desktop/Fiji.app/models/model_03bioimageio", null);
+ // Unpack the args needed
+ if (args.length < 4)
+ throw new IllegalArgumentException("Error exectuting Tensorflow 2, "
+ + "at least 5 arguments are required:" + System.lineSeparator()
+ + " - Folder where the model is located" + System.lineSeparator()
+ + " - Temporary dir where the memory mapped files are located" + System.lineSeparator()
+ + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator()
+ + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator()
+ + " - ...." + System.lineSeparator()
+ + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator()
+ + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator()
+ + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator()
+ + " - ...." + System.lineSeparator()
+ + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator()
+ );
+ String modelFolder = args[0];
+ if (!(new File(modelFolder).isDirectory())) {
+ throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' "
+ + "should be an existing directory containing a Tensorflow 2 model.");
+ }
+
+ Tensorflow2Interface2_old tfInterface = new Tensorflow2Interface2_old(false);
+ tfInterface.tmpDir = args[1];
+ if (!(new File(args[1]).isDirectory())) {
+ throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' "
+ + "should be an existing directory.");
+ }
+
+ tfInterface.loadModel(modelFolder, modelFolder);
+
+ HashMap<String, List<String>> map = tfInterface.getInputTensorsFileNames(args);
+ List<String> inputNames = map.get(INPUTS_MAP_KEY);
+ List<Tensor<?>> inputList = inputNames.stream().map(n -> {
+ try {
+ return tfInterface.retrieveInterprocessingTensorsByName(n);
+ } catch (RunModelException e) {
+ return null;
+ }
+ }).collect(Collectors.toList());
+ List<String> outputNames = map.get(OUTPUTS_MAP_KEY);
+ List<Tensor<?>> outputList = outputNames.stream().map(n -> {
+ try {
+ return tfInterface.retrieveInterprocessingTensorsByName(n);
+ } catch (RunModelException e) {
+ return null;
+ }
+ }).collect(Collectors.toList());
+ tfInterface.run(inputList, outputList);
+ tfInterface.createTensorsForInterprocessing(outputList);
+ }
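+
+ // Illustrative only (hypothetical paths and tensor names): the parent process is expected to
+ // invoke this main method with an argument layout along the lines of
+ //   <modelFolder> <tmpDir> input0_<timestamp>_model_input output0_<timestamp>_model_output
+ // where the suffixes correspond to INPUT_FILE_TERMINATION and OUTPUT_FILE_TERMINATION.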
+
+ /**
+ * Get the name of the temporary file associated with the tensor name
+ * @param name
+ * name of the tensor
+ * @return file name associated with the tensor
+ */
+ private String getFilename4Tensor(String name) {
+ if (tensorFilenameMap == null)
+ tensorFilenameMap = new HashMap<String, String>();
+ if (tensorFilenameMap.get(name) != null)
+ return tensorFilenameMap.get(name);
+ LocalDateTime now = LocalDateTime.now();
+ DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS");
+ String newName = name + "_" + now.format(formatter);
+ tensorFilenameMap.put(name, newName);
+ return tensorFilenameMap.get(name);
+ }
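+
+ // Example (illustrative timestamp): getFilename4Tensor("input0") might return
+ // "input0_20240131120000123" (the tensor name plus a yyyyMMddHHmmssSSS suffix); the mapping is
+ // cached, so subsequent calls with the same tensor name reuse the same file name.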
+
+ /**
+ * Create a temporary file for each of the tensors in the list to communicate with
+ * the separate process in MacOS Intel and Windows systems
+ * @param tensors
+ * list of tensors to be sent
+ * @throws RunModelException if there is any error converting the tensors
+ */
+ private void createTensorsForInterprocessing(List<Tensor<?>> tensors) throws RunModelException{
+ if (this.listTempFiles == null)
+ this.listTempFiles = new ArrayList<File>();
+ for (Tensor<?> tensor : tensors) {
+ long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor);
+ File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION);
+ if (!ff.exists()) {
+ ff.deleteOnExit();
+ this.listTempFiles.add(ff);
+ }
+ try (RandomAccessFile rd =
+ new RandomAccessFile(ff, "rw");
+ FileChannel fc = rd.getChannel();) {
+ MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile);
+ ByteBuffer byteBuffer = mem.duplicate();
+ ImgLib2ToMappedBuffer.build(tensor, byteBuffer);
+ } catch (IOException e) {
+ closeModel();
+ throw new RunModelException(e.getCause() != null ? e.getCause().toString() : e.toString());
+ }
+ }
+ }
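+
+ // Sketch of the handoff as written here: this side maps a file of
+ // ImgLib2ToMappedBuffer.findTotalLengthFile(tensor) bytes and serialises the tensor into it;
+ // the peer process then maps the same file read-only and rebuilds the tensor through
+ // MappedBufferToImgLib2 (see retrieveInterprocessingTensors* below).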
+
+ /**
+ * Retrieves the data of the tensors contained in the input list from the output
+ * generated by the independent process
+ * @param tensors
+ * list of tensors that are going to be filled
+ * @throws RunModelException if there is any issue retrieving the data from the other process
+ */
+ private void retrieveInterprocessingTensors(List<Tensor<?>> tensors) throws RunModelException{
+ for (Tensor tensor : tensors) {
+ try (RandomAccessFile rd =
+ new RandomAccessFile(tmpDir + File.separator
+ + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r");
+ FileChannel fc = rd.getChannel();) {
+ MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
+ ByteBuffer byteBuffer = mem.duplicate();
+ tensor.setData(MappedBufferToImgLib2.build(byteBuffer));
+ } catch (IOException e) {
+ closeModel();
+ throw new RunModelException(e.getCause() != null ? e.getCause().toString() : e.toString());
+ }
+ }
+ }
+
+ /**
+ * Create a tensor from the data contained in a file named as the parameter
+ * provided as an input + the file extension {@link #FILE_EXTENSION}.
+ * This file is produced by another process to communicate with the current process
+ * @param <T>
+ * generic type of the tensor
+ * @param name
+ * name of the file without the extension ({@link #FILE_EXTENSION}).
+ * @return a tensor created with the data in the file
+ * @throws RunModelException if there is any problem retrieving the data and creating the tensor
+ */
+ private < T extends RealType< T > & NativeType< T > > Tensor<T>
+ retrieveInterprocessingTensorsByName(String name) throws RunModelException {
+ try (RandomAccessFile rd =
+ new RandomAccessFile(tmpDir + File.separator
+ + this.getFilename4Tensor(name) + FILE_EXTENSION, "r");
+ FileChannel fc = rd.getChannel();) {
+ MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size());
+ ByteBuffer byteBuffer = mem.duplicate();
+ return MappedBufferToImgLib2.buildTensor(byteBuffer);
+ } catch (IOException e) {
+ closeModel();
+ throw new RunModelException(e.getCause() != null ? e.getCause().toString() : e.toString());
+ }
+ }
+
+ /**
+ * If the java bin path contains any special character, surround it with double quotes
+ * @param javaBin
+ * path to the java executable
+ * @return the java bin path, quoted if needed
+ */
+ private static String padSpecialJavaBin(String javaBin) {
+ String[] specialChars = new String[] {" "};
+ for (String schar : specialChars) {
+ if (javaBin.contains(schar) && PlatformDetection.isWindows()) {
+ return "\"" + javaBin + "\"";
+ }
+ }
+ return javaBin;
+ }
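+
+ // Illustrative (hypothetical path): on Windows, padSpecialJavaBin("C:\\Program Files\\Java\\bin\\java")
+ // returns the path wrapped in double quotes; on other platforms, or when the path contains no
+ // space, the argument is returned unchanged.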
+
+ /**
+ * Create the arguments needed to execute TensorFlow 2 in another
+ * process with the corresponding tensors
+ * @return the command used to call the separate process
+ * @throws IOException if the command needed to execute interprocessing is too long
+ * @throws URISyntaxException if there is any error with the URIs retrieved from the classes
+ */
+ private List<String> getProcessCommandsWithoutArgs() throws IOException, URISyntaxException {
+ String javaHome = System.getProperty("java.home");
+ String javaBin = javaHome + File.separator + "bin" + File.separator + "java";
+
+ String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
+ String imglib2Path = getPathFromClass(NativeType.class);
+ if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
+ && !modelrunnerPath.contains(File.pathSeparator)))
+ modelrunnerPath = System.getProperty("java.class.path");
+ String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
+ ProtectionDomain protectionDomain = Tensorflow2Interface2_old.class.getProtectionDomain();
+ String codeSource = protectionDomain.getCodeSource().getLocation().getPath();
+ String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString());
+ f_name = new File(f_name).getAbsolutePath();
+ for (File ff : new File(f_name).getParentFile().listFiles()) {
+ if (ff.getName().startsWith(JAR_FILE_NAME) && !ff.getAbsolutePath().equals(f_name))
+ continue;
+ classpath += ff.getAbsolutePath() + File.pathSeparator;
+ }
+ String className = Tensorflow2Interface2_old.class.getName();
+ List<String> command = new LinkedList<String>();
+ command.add(padSpecialJavaBin(javaBin));
+ command.add("-cp");
+ command.add(classpath);
+ command.add(className);
+ command.add(modelFolder);
+ command.add(this.tmpDir);
+ return command;
+ }
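+
+ // Illustrative result (hypothetical paths): the returned list is roughly
+ //   [<java.home>/bin/java, -cp, <modelrunner.jar>:<imglib2.jar>:...,
+ //    io.bioimage...Tensorflow2Interface2_old, <modelFolder>, <tmpDir>]
+ // As the method name suggests, the per-tensor file-name arguments are presumably appended
+ // elsewhere before the process is started.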
+
+ /**
+ * Method that gets the path to the JAR from where a specific class is being loaded
+ * @param clazz
+ * class of interest
+ * @return the path to the JAR that contains the class
+ * @throws UnsupportedEncodingException if the URL of the JAR cannot be decoded as UTF-8
+ */
+ private static String getPathFromClass(Class<?> clazz) throws UnsupportedEncodingException {
+ String classResource = clazz.getName().replace('.', '/') + ".class";
+ URL resourceUrl = clazz.getClassLoader().getResource(classResource);
+ if (resourceUrl == null) {
+ return null;
+ }
+ String urlString = resourceUrl.toString();
+ if (urlString.startsWith("jar:")) {
+ urlString = urlString.substring(4);
+ }
+ if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) {
+ urlString = urlString.substring(6);
+ } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) {
+ urlString = urlString.substring(5);
+ }
+ urlString = URLDecoder.decode(urlString, "UTF-8");
+ File file = new File(urlString);
+ String path = file.getAbsolutePath();
+ if (path.lastIndexOf(".jar!") != -1)
+ path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar";
+ return path;
+ }
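+
+ // Illustrative (hypothetical location): a resource URL such as
+ // "jar:file:/opt/libs/engine.jar!/io/bioimage/Foo.class" is reduced to "/opt/libs/engine.jar";
+ // for classes loaded from a directory, the absolute path of the .class file itself is returned.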
+
+ /**
+ * Get the temporary directory used for the interprocessing communication on Intel macOS
+ * and Windows
+ * @return the tmp dir
+ * @throws IOException if the files cannot be written in any of the temp dirs
+ */
+ private static String getTemporaryDir() throws IOException {
+ String tmpDir;
+ String enginesDir = getEnginesDir();
+ if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) {
+ tmpDir = enginesDir + File.separator + "temp";
+ if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs()))
+ tmpDir = enginesDir;
+ } else if (System.getenv("temp") != null
+ && Files.isWritable(Paths.get(System.getenv("temp")))) {
+ return System.getenv("temp");
+ } else if (System.getenv("TEMP") != null
+ && Files.isWritable(Paths.get(System.getenv("TEMP")))) {
+ return System.getenv("TEMP");
+ } else if (System.getenv("tmp") != null
+ && Files.isWritable(Paths.get(System.getenv("tmp")))) {
+ return System.getenv("tmp");
+ } else if (System.getenv("TMP") != null
+ && Files.isWritable(Paths.get(System.getenv("TMP")))) {
+ return System.getenv("TMP");
+ } else if (System.getProperty("java.io.tmpdir") != null
+ && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) {
+ return System.getProperty("java.io.tmpdir");
+ } else {
+ throw new IOException("Unable to find temporal directory with writting rights. "
+ + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'.");
+ }
+ return tmpDir;
+ }
+
+ /**
+ * Get the directory where the TF2 engine is located, used as a fallback if no temporary dir is found
+ * @return directory of the engines
+ */
+ private static String getEnginesDir() {
+ String dir;
+ try {
+ dir = getPathFromClass(Tensorflow2Interface2_old.class);
+ } catch (UnsupportedEncodingException e) {
+ String classResource = Tensorflow2Interface2_old.class.getName().replace('.', '/') + ".class";
+ URL resourceUrl = Tensorflow2Interface2_old.class.getClassLoader().getResource(classResource);
+ if (resourceUrl == null) {
+ return null;
+ }
+ String urlString = resourceUrl.toString();
+ if (urlString.startsWith("jar:")) {
+ urlString = urlString.substring(4);
+ }
+ if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) {
+ urlString = urlString.substring(6);
+ } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) {
+ urlString = urlString.substring(5);
+ }
+ File file = new File(urlString);
+ String path = file.getAbsolutePath();
+ if (path.lastIndexOf(".jar!") != -1)
+ path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar";
+ dir = path;
+ }
+ return new File(dir).getParent();
+ }
+
+ /**
+ * Retrieve the file names used for interprocess communication
+ * @param args
+ * args provided to the main method
+ * @return a map with a list of input and output names
+ */
+ private HashMap<String, List<String>> getInputTensorsFileNames(String[] args) {
+ List<String> inputNames = new ArrayList<String>();
+ List<String> outputNames = new ArrayList<String>();
+ if (this.tensorFilenameMap == null)
+ this.tensorFilenameMap = new HashMap<String, String>();
+ for (int i = 2; i < args.length; i ++) {
+ if (args[i].endsWith(INPUT_FILE_TERMINATION)) {
+ String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length());
+ String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_"));
+ inputNames.add(onlyName);
+ tensorFilenameMap.put(onlyName, nameWTimestamp);
+ } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) {
+ String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length());
+ String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_"));
+ outputNames.add(onlyName);
+ tensorFilenameMap.put(onlyName, nameWTimestamp);
+
+ }
+ }
+ if (inputNames.size() == 0)
+ throw new IllegalArgumentException("The args to the main method of '"
+ + Tensorflow2Interface2_old.class.toString() + "' should contain at "
+ + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'.");
+ if (outputNames.size() == 0)
+ throw new IllegalArgumentException("The args to the main method of '"
+ + Tensorflow2Interface2_old.class.toString() + "' should contain at "
+ + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'.");
+ HashMap<String, List<String>> map = new HashMap<String, List<String>>();
+ map.put(INPUTS_MAP_KEY, inputNames);
+ map.put(OUTPUTS_MAP_KEY, outputNames);
+ return map;
+ }
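+
+ // Illustrative (hypothetical argument): an arg such as "input0_20240131120000123_model_input"
+ // is split into the tensor name "input0" and the file name "input0_20240131120000123",
+ // assuming INPUT_FILE_TERMINATION is "_model_input" as described in the main() javadoc.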
+
+ /**
+ * Method to obtain the String output of the process in case something goes wrong
+ * @param process
+ * the process that executed the TF2 model
+ * @return the String output that we would have seen on the terminal
+ * @throws IOException if the output of the process cannot be read
+ */
+ private static String readProcessStringOutput(Process process) throws IOException {
+ StringBuilder text = new StringBuilder();
+ try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
+ BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
+ String line;
+ while ((line = bufferedErrReader.readLine()) != null) {
+ text.append(line).append(System.lineSeparator());
+ }
+ while ((line = bufferedReader.readLine()) != null) {
+ text.append(line).append(System.lineSeparator());
+ }
+ }
+ return text.toString();
+ }
+}