diff --git a/pom.xml b/pom.xml index 6d802f0..597f457 100644 --- a/pom.xml +++ b/pom.xml @@ -125,7 +125,7 @@ sign,deploy-to-scijava - 0.5.8 + 0.5.10-SNAPSHOT 0.2.0 diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java new file mode 100644 index 0000000..8c5c3a0 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java @@ -0,0 +1,150 @@ +package io.bioimage.modelrunner.tensorflow.v2.api020; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Scanner; + +import io.bioimage.modelrunner.apposed.appose.Types; +import io.bioimage.modelrunner.apposed.appose.Service.RequestType; +import io.bioimage.modelrunner.apposed.appose.Service.ResponseType; + +public class JavaWorker { + + private static LinkedHashMap tasks = new LinkedHashMap(); + + private final String uuid; + + private final Tensorflow2Interface ti; + + private boolean cancelRequested = false; + + public static void main(String[] args) { + + try(Scanner scanner = new Scanner(System.in)){ + Tensorflow2Interface ti; + try { + ti = new Tensorflow2Interface(false); + } catch (IOException | URISyntaxException e) { + return; + } + + while (true) { + String line; + try { + if (!scanner.hasNextLine()) break; + line = scanner.nextLine().trim(); + } catch (Exception e) { + break; + } + + if (line.isEmpty()) break; + Map request = Types.decode(line); + String uuid = (String) request.get("task"); + String requestType = (String) request.get("requestType"); + + if (requestType.equals(RequestType.EXECUTE.toString())) { + String script = (String) request.get("script"); + Map inputs = (Map) request.get("inputs"); + JavaWorker task = new JavaWorker(uuid, ti); + tasks.put(uuid, task); + task.start(script, inputs); + } else if (requestType.equals(RequestType.CANCEL.toString())) { + JavaWorker task = (JavaWorker) tasks.get(uuid); + if (task == null) { + System.err.println("No such task: " + uuid); + continue; + } + task.cancelRequested = true; + } else { + break; + } + } + } + + } + + private JavaWorker(String uuid, Tensorflow2Interface ti) { + this.uuid = uuid; + this.ti = ti; + } + + private void executeScript(String script, Map inputs) { + Map binding = new LinkedHashMap(); + binding.put("task", this); + if (inputs != null) + binding.putAll(binding); + + this.reportLaunch(); + try { + if (script.equals("loadModel")) { + ti.loadModel((String) inputs.get("modelFolder"), null); + } else if (script.equals("inference")) { + ti.runFromShmas((List) inputs.get("inputs"), (List) inputs.get("outputs")); + } else if (script.equals("close")) { + ti.closeModel(); + } + } catch(Exception ex) { + this.fail(Types.stackTrace(ex)); + return; + } + this.reportCompletion(); + } + + private void start(String script, Map inputs) { + new Thread(() -> executeScript(script, inputs), "Appose-" + this.uuid).start(); + } + + private void reportLaunch() { + respond(ResponseType.LAUNCH, null); + } + + private void reportCompletion() { + respond(ResponseType.COMPLETION, null); + } + + private void update(String message, Integer current, Integer maximum) { + LinkedHashMap args = new LinkedHashMap(); + + if (message != null) + args.put("message", message); + + if (current != null) + args.put("current", current); + + if (maximum != null) + args.put("maximum", maximum); + this.respond(ResponseType.UPDATE, args); + } + + private void respond(ResponseType responseType, Map args) { + Map response = new HashMap(); + response.put("task", uuid); + response.put("responseType", responseType); + if (args != null) + response.putAll(args); + try { + System.out.println(Types.encode(response)); + System.out.flush(); + } catch(Exception ex) { + this.fail(Types.stackTrace(ex.getCause())); + } + } + + private void cancel() { + this.respond(ResponseType.CANCELATION, null); + } + + private void fail(String error) { + Map args = null; + if (error != null) { + args = new HashMap(); + args.put("error", error); + } + respond(ResponseType.FAILURE, args); + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java index 7de3713..008b989 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java @@ -20,9 +20,15 @@ */ package io.bioimage.modelrunner.tensorflow.v2.api020; +import com.google.gson.Gson; import com.google.protobuf.InvalidProtocolBufferException; +import io.bioimage.modelrunner.apposed.appose.Service; +import io.bioimage.modelrunner.apposed.appose.Types; +import io.bioimage.modelrunner.apposed.appose.Service.Task; +import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; import io.bioimage.modelrunner.bioimageio.download.DownloadModel; import io.bioimage.modelrunner.engine.DeepLearningEngineInterface; import io.bioimage.modelrunner.engine.EngineInfo; @@ -30,14 +36,19 @@ import io.bioimage.modelrunner.exceptions.RunModelException; import io.bioimage.modelrunner.system.PlatformDetection; import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer; import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2; +import io.bioimage.modelrunner.utils.CommonUtils; import io.bioimage.modelrunner.utils.Constants; import io.bioimage.modelrunner.utils.ZipUtils; +import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; +import net.imglib2.util.Cast; +import net.imglib2.util.Util; import java.io.BufferedReader; import java.io.File; @@ -59,8 +70,10 @@ import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.tensorflow.SavedModelBundle; @@ -68,6 +81,7 @@ import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.family.TType; /** * Class to that communicates with the dl-model runner, see @@ -86,66 +100,23 @@ * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando */ public class Tensorflow2Interface implements DeepLearningEngineInterface { - - private static final String[] MODEL_TAGS = { "serve", "inference", "train", - "eval", "gpu", "tpu" }; - - private static final String[] TF_MODEL_TAGS = { - "tf.saved_model.tag_constants.SERVING", - "tf.saved_model.tag_constants.INFERENCE", - "tf.saved_model.tag_constants.TRAINING", - "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", - "tf.saved_model.tag_constants.TPU" }; - - private static final String[] SIGNATURE_CONSTANTS = { "serving_default", - "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs", - "tensorflow/serving/predict", "outputs", "inputs", - "tensorflow/serving/regress", "outputs", "train", "eval", - "tensorflow/supervised/training", "tensorflow/supervised/eval" }; - - private static final String[] TF_SIGNATURE_CONSTANTS = { - "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.CLASSIFY_INPUTS", - "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", - "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", - "tf.saved_model.signature_constants.PREDICT_INPUTS", - "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", - "tf.saved_model.signature_constants.PREDICT_OUTPUTS", - "tf.saved_model.signature_constants.REGRESS_INPUTS", - "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", - "tf.saved_model.signature_constants.REGRESS_OUTPUTS", - "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", - "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", - "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" }; - - /** - * Idetifier for the files that contain the data of the inputs - */ - final private static String INPUT_FILE_TERMINATION = "_model_input"; - - /** - * Idetifier for the files that contain the data of the outputs - */ - final private static String OUTPUT_FILE_TERMINATION = "_model_output"; - /** - * Key for the inputs in the map that retrieves the file names for interprocess communication - */ - final private static String INPUTS_MAP_KEY = "inputs"; - /** - * Key for the outputs in the map that retrieves the file names for interprocess communication - */ - final private static String OUTPUTS_MAP_KEY = "outputs"; - /** - * File extension for the temporal files used for interprocessing - */ - final private static String FILE_EXTENSION = ".dat"; /** * Name without vesion of the JAR created for this library */ private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-"; + private static final String NAME_KEY = "name"; + private static final String SHAPE_KEY = "shape"; + private static final String DTYPE_KEY = "dtype"; + private static final String IS_INPUT_KEY = "isInput"; + private static final String MEM_NAME_KEY = "memoryName"; + + private List shmaInputList = new ArrayList(); + + private List shmaOutputList = new ArrayList(); + + private List shmaNamesList = new ArrayList(); + /** * The loaded Tensorflow 2 model */ @@ -158,26 +129,14 @@ public class Tensorflow2Interface implements DeepLearningEngineInterface { * Whether the execution needs interprocessing (MacOS Interl) or not */ private boolean interprocessing = false; - /** - * TEmporary dir where to store temporary files - */ - private String tmpDir; /** * Folde containing the model that is being executed */ private String modelFolder; - /** - * List of temporary files used for interprocessing communication - */ - private List listTempFiles; - /** - * HashMap that maps tensor to the temporal file name for interprocessing - */ - private HashMap tensorFilenameMap; /** * Process where the model is being loaded and executed */ - Process process; + Service runner; /** * TODO the interprocessing is executed for every OS @@ -185,15 +144,11 @@ public class Tensorflow2Interface implements DeepLearningEngineInterface { * executed is Windows or Mac or not to know if it is going to need interprocessing * or not * @throws IOException if the temporary dir is not found + * @throws URISyntaxException */ - public Tensorflow2Interface() throws IOException + public Tensorflow2Interface() throws IOException, URISyntaxException { - boolean isWin = PlatformDetection.isWindows(); - boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); - if (true || (isWin && isIntel)) { - interprocessing = true; - tmpDir = getTemporaryDir(); - } + this(true); } /** @@ -203,19 +158,23 @@ public Tensorflow2Interface() throws IOException * @param doInterprocessing * whether to do interprocessing or not * @throws IOException if the temp dir is not found + * @throws URISyntaxException */ - private Tensorflow2Interface(boolean doInterprocessing) throws IOException + protected Tensorflow2Interface(boolean doInterprocessing) throws IOException, URISyntaxException { - if (!doInterprocessing) { - interprocessing = false; - } else { - boolean isWin = PlatformDetection.isMacOS(); - boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); - if (isWin && isIntel) { - interprocessing = true; - tmpDir = getTemporaryDir(); - } - } + interprocessing = doInterprocessing; + if (this.interprocessing) { + runner = getRunner(); + runner.debug((text) -> System.err.println(text)); + } + } + + private Service getRunner() throws IOException, URISyntaxException { + List args = getProcessCommandsWithoutArgs(); + String[] argArr = new String[args.size()]; + args.toArray(argArr); + + return new Service(new File("."), argArr); } /** @@ -231,8 +190,14 @@ public void loadModel(String modelFolder, String modelSource) throws LoadModelException { this.modelFolder = modelFolder; - if (interprocessing) + if (interprocessing) { + try { + launchModelLoadOnProcess(); + } catch (IOException | InterruptedException e) { + throw new LoadModelException(Types.stackTrace(e)); + } return; + } try { checkModelUnzipped(); } catch (Exception e) { @@ -249,6 +214,19 @@ public void loadModel(String modelFolder, String modelSource) } } + private void launchModelLoadOnProcess() throws IOException, InterruptedException { + HashMap args = new HashMap(); + args.put("modelFolder", modelFolder); + Task task = runner.task("loadModel", args); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + } + /** * Check if an unzipped tensorflow model exists in the model folder, * and if not look for it and unzip it @@ -260,7 +238,7 @@ private void checkModelUnzipped() throws LoadModelException, IOException, Except if (new File(modelFolder, "variables").isDirectory() && new File(modelFolder, "saved_model.pb").isFile()) return; - unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); + unzipTfWeights(ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); } /** @@ -301,7 +279,8 @@ private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelExceptio * and modifies the output list with the results obtained */ @Override - public void run(List> inputTensors, List> outputTensors) + public & NativeType, R extends RealType & NativeType> + void run(List> inputTensors, List> outputTensors) throws RunModelException { if (interprocessing) { @@ -311,30 +290,65 @@ public void run(List> inputTensors, List> outputTensors) Session session = model.session(); Session.Runner runner = session.runner(); List inputListNames = new ArrayList(); - List> inTensors = - new ArrayList>(); + List inTensors = new ArrayList(); int c = 0; - for (Tensor tt : inputTensors) { + for (Tensor tt : inputTensors) { inputListNames.add(tt.getName()); - org.tensorflow.Tensor inT = TensorBuilder.build(tt); + TType inT = TensorBuilder.build(tt); inTensors.add(inT); String inputName = getModelInputName(tt.getName(), c ++); runner.feed(inputName, inT); } c = 0; - for (Tensor tt : outputTensors) + for (Tensor tt : outputTensors) runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); // Run runner - List> resultPatchTensors = runner.run(); + List resultPatchTensors = runner.run(); // Fill the agnostic output tensors list with data from the inference result fillOutputTensors(resultPatchTensors, outputTensors); // Close the remaining resources - session.close(); - for (org.tensorflow.Tensor tt : inTensors) { + for (TType tt : inTensors) { tt.close(); } - for (org.tensorflow.Tensor tt : resultPatchTensors) { + for (org.tensorflow.Tensor tt : resultPatchTensors) { + tt.close(); + } + } + + protected void runFromShmas(List inputs, List outputs) throws IOException { + Session session = model.session(); + Session.Runner runner = session.runner(); + + List inTensors = new ArrayList(); + int c = 0; + for (String ee : inputs) { + Map decoded = Types.decode(ee); + SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY)); + TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma); + if (PlatformDetection.isWindows()) shma.close(); + inTensors.add(inT); + String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++); + runner.feed(inputName, inT); + } + + c = 0; + for (String ee : outputs) + runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++)); + // Run runner + List resultPatchTensors = runner.run(); + + // Fill the agnostic output tensors list with data from the inference result + c = 0; + for (String ee : outputs) { + Map decoded = Types.decode(ee); + ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY)); + } + // Close the remaining resources + for (TType tt : inTensors) { + tt.close(); + } + for (org.tensorflow.Tensor tt : resultPatchTensors) { tt.close(); } } @@ -349,32 +363,110 @@ public void run(List> inputTensors, List> outputTensors) * expected results of the model * @throws RunModelException if there is any issue running the model */ - public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { - createTensorsForInterprocessing(inputTensors); - createTensorsForInterprocessing(outputTensors); + public & NativeType, R extends RealType & NativeType> + void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { + shmaInputList = new ArrayList(); + shmaOutputList = new ArrayList(); + List encIns = modifyForWinCmd(encodeInputs(inputTensors)); + List encOuts = modifyForWinCmd(encodeOutputs(outputTensors)); + LinkedHashMap args = new LinkedHashMap(); + args.put("inputs", encIns); + args.put("outputs", encOuts); + try { - List args = getProcessCommandsWithoutArgs(); - for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);} - for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);} - ProcessBuilder builder = new ProcessBuilder(args); - builder.redirectOutput(ProcessBuilder.Redirect.INHERIT); - builder.redirectError(ProcessBuilder.Redirect.INHERIT); - process = builder.start(); - int result = process.waitFor(); - process.destroy(); - if (result != 0) - throw new RunModelException("Error executing the Tensorflow 2 model in" - + " a separate process. The process was not terminated correctly." - + System.lineSeparator() + readProcessStringOutput(process)); - } catch (RunModelException e) { - closeModel(); - throw e; + Task task = runner.task("inference", args); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + for (int i = 0; i < outputTensors.size(); i ++) { + String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY); + SharedMemoryArray shm = shmaOutputList.stream() + .filter(ss -> ss.getName().equals(name)).findFirst().orElse(null); + if (shm == null) { + shm = SharedMemoryArray.read(name); + shmaOutputList.add(shm); + } + RandomAccessibleInterval rai = shm.getSharedRAI(); + outputTensors.get(i).setData(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai)))); + } } catch (Exception e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); + closeShmas(); + if (e instanceof RunModelException) + throw (RunModelException) e; + throw new RunModelException(Types.stackTrace(e)); } - - retrieveInterprocessingTensors(outputTensors); + closeShmas(); + } + + private void closeShmas() { + shmaInputList.forEach(shm -> { + try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} + }); + shmaInputList = null; + shmaOutputList.forEach(shm -> { + try { shm.close(); } catch (IOException e1) { e1.printStackTrace();} + }); + shmaOutputList = null; + } + + private static List modifyForWinCmd(List ins){ + if (!PlatformDetection.isWindows()) + return ins; + List newIns = new ArrayList(); + for (String ii : ins) + newIns.add("\"" + ii.replace("\"", "\\\"") + "\""); + return newIns; + } + + + private & NativeType> List encodeInputs(List> inputTensors) { + List encodedInputTensors = new ArrayList(); + Gson gson = new Gson(); + for (Tensor tt : inputTensors) { + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaInputList.add(shma); + HashMap map = new HashMap(); + map.put(NAME_KEY, tt.getName()); + map.put(SHAPE_KEY, tt.getShape()); + map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData())); + map.put(IS_INPUT_KEY, true); + map.put(MEM_NAME_KEY, shma.getName()); + encodedInputTensors.add(gson.toJson(map)); + } + return encodedInputTensors; + } + + + private & NativeType> + List encodeOutputs(List> outputTensors) { + Gson gson = new Gson(); + List encodedOutputTensors = new ArrayList(); + for (Tensor tt : outputTensors) { + HashMap map = new HashMap(); + map.put(NAME_KEY, tt.getName()); + map.put(IS_INPUT_KEY, false); + if (!tt.isEmpty()) { + map.put(SHAPE_KEY, tt.getShape()); + map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData())); + SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true); + shmaOutputList.add(shma); + map.put(MEM_NAME_KEY, shma.getName()); + } else if (PlatformDetection.isWindows()){ + SharedMemoryArray shma = SharedMemoryArray.create(0); + shmaOutputList.add(shma); + map.put(MEM_NAME_KEY, shma.getName()); + } else { + String memName = SharedMemoryArray.createShmName(); + map.put(MEM_NAME_KEY, memName); + shmaNamesList.add(memName); + } + encodedOutputTensors.add(gson.toJson(map)); + } + return encodedOutputTensors; } /** @@ -386,9 +478,10 @@ public void runInterprocessing(List> inputTensors, List> out * @throws RunModelException If the number of tensors expected is not the same * as the number of Tensors outputed by the model */ - public static void fillOutputTensors( + public static & NativeType> + void fillOutputTensors( List> outputTfTensors, - List> outputTensors) throws RunModelException + List> outputTensors) throws RunModelException { if (outputTfTensors.size() != outputTensors.size()) throw new RunModelException(outputTfTensors.size(), outputTensors.size()); @@ -406,21 +499,31 @@ public static void fillOutputTensors( */ @Override public void closeModel() { + if (this.interprocessing && runner != null) { + Task task; + try { + task = runner.task("close"); + task.waitFor(); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(Types.stackTrace(e)); + } + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + this.runner.close(); + return; + } else if (this.interprocessing) { + return; + } sig = null; if (model != null) { model.session().close(); model.close(); } model = null; - if (listTempFiles == null) - return; - for (File ff : listTempFiles) { - if (ff.exists()) - ff.delete(); - } - listTempFiles = null; - if (process != null) - process.destroyForcibly(); } // TODO make only one @@ -506,157 +609,8 @@ public static String getModelOutputName(String outputName, int i) { * @throws RunModelException if there is any error running the model */ public static void main(String[] args) throws LoadModelException, IOException, RunModelException { - Tensorflow2Interface tt = new Tensorflow2Interface(false); - - tt.loadModel("/home/carlos/Desktop/Fiji.app/models/model_03bioimageio", null); - // Unpack the args needed - if (args.length < 4) - throw new IllegalArgumentException("Error exectuting Tensorflow 2, " - + "at least 5 arguments are required:" + System.lineSeparator() - + " - Folder where the model is located" + System.lineSeparator() - + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() - + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() - + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() - + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - + " - ...." + System.lineSeparator() - + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() - ); - String modelFolder = args[0]; - if (!(new File(modelFolder).isDirectory())) { - throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " - + "should be an existing directory containing a Tensorflow 2 model."); - } - - Tensorflow2Interface tfInterface = new Tensorflow2Interface(false); - tfInterface.tmpDir = args[1]; - if (!(new File(args[1]).isDirectory())) { - throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " - + "should be an existing directory."); - } - - tfInterface.loadModel(modelFolder, modelFolder); - - HashMap> map = tfInterface.getInputTensorsFileNames(args); - List inputNames = map.get(INPUTS_MAP_KEY); - List> inputList = inputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - List outputNames = map.get(OUTPUTS_MAP_KEY); - List> outputList = outputNames.stream().map(n -> { - try { - return tfInterface.retrieveInterprocessingTensorsByName(n); - } catch (RunModelException e) { - return null; - } - }).collect(Collectors.toList()); - tfInterface.run(inputList, outputList); - tfInterface.createTensorsForInterprocessing(outputList); - } - - /** - * Get the name of the temporary file associated to the tensor name - * @param name - * name of the tensor - * @return file name associated to the tensor - */ - private String getFilename4Tensor(String name) { - if (tensorFilenameMap == null) - tensorFilenameMap = new HashMap(); - if (tensorFilenameMap.get(name) != null) - return tensorFilenameMap.get(name); - LocalDateTime now = LocalDateTime.now(); - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS"); - String newName = name + "_" + now.format(formatter); - tensorFilenameMap.put(name, newName); - return tensorFilenameMap.get(name); } - /** - * Create a temporary file for each of the tensors in the list to communicate with - * the separate process in MacOS Intel and Windows systems - * @param tensors - * list of tensors to be sent - * @throws RunModelException if there is any error converting the tensors - */ - private void createTensorsForInterprocessing(List> tensors) throws RunModelException{ - if (this.listTempFiles == null) - this.listTempFiles = new ArrayList(); - for (Tensor tensor : tensors) { - long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor); - File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION); - if (!ff.exists()) { - ff.deleteOnExit(); - this.listTempFiles.add(ff); - } - try (RandomAccessFile rd = - new RandomAccessFile(ff, "rw"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); - ByteBuffer byteBuffer = mem.duplicate(); - ImgLib2ToMappedBuffer.build(tensor, byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Retrieves the data of the tensors contained in the input list from the output - * generated by the independent process - * @param tensors - * list of tensors that are going to be filled - * @throws RunModelException if there is any issue retrieving the data from the other process - */ - private void retrieveInterprocessingTensors(List> tensors) throws RunModelException{ - for (Tensor tensor : tensors) { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - } - - /** - * Create a tensor from the data contained in a file named as the parameter - * provided as an input + the file extension {@link #FILE_EXTENSION}. - * This file is produced by another process to communicate with the current process - * @param - * generic type of the tensor - * @param name - * name of the file without the extension ({@link #FILE_EXTENSION}). - * @return a tensor created with the data in the file - * @throws RunModelException if there is any problem retrieving the data and cerating the tensor - */ - private < T extends RealType< T > & NativeType< T > > Tensor - retrieveInterprocessingTensorsByName(String name) throws RunModelException { - try (RandomAccessFile rd = - new RandomAccessFile(tmpDir + File.separator - + this.getFilename4Tensor(name) + FILE_EXTENSION, "r"); - FileChannel fc = rd.getChannel();) { - MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); - ByteBuffer byteBuffer = mem.duplicate(); - return MappedBufferToImgLib2.buildTensor(byteBuffer); - } catch (IOException e) { - closeModel(); - throw new RunModelException(e.getCause().toString()); - } - } - /** * if java bin dir contains any special char, surround it by double quotes * @param javaBin @@ -684,12 +638,7 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt String javaHome = System.getProperty("java.home"); String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; - String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); - String imglib2Path = getPathFromClass(NativeType.class); - if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") - && !modelrunnerPath.contains(File.pathSeparator))) - modelrunnerPath = System.getProperty("java.class.path"); - String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; + String classpath = getCurrentClasspath(); ProtectionDomain protectionDomain = Tensorflow2Interface.class.getProtectionDomain(); String codeSource = protectionDomain.getCodeSource().getLocation().getPath(); String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString()); @@ -699,17 +648,34 @@ private List getProcessCommandsWithoutArgs() throws IOException, URISynt continue; classpath += ff.getAbsolutePath() + File.pathSeparator; } - String className = Tensorflow2Interface.class.getName(); + String className = JavaWorker.class.getName(); List command = new LinkedList(); command.add(padSpecialJavaBin(javaBin)); command.add("-cp"); command.add(classpath); command.add(className); - command.add(modelFolder); - command.add(this.tmpDir); return command; } + private static String getCurrentClasspath() throws UnsupportedEncodingException { + + String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); + String imglib2Path = getPathFromClass(NativeType.class); + String gsonPath = getPathFromClass(Gson.class); + String jnaPath = getPathFromClass(com.sun.jna.Library.class); + String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class); + if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") + && !modelrunnerPath.contains(File.pathSeparator))) + modelrunnerPath = System.getProperty("java.class.path"); + modelrunnerPath = System.getProperty("java.class.path"); + String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; + classpath = classpath + gsonPath + File.pathSeparator; + classpath = classpath + jnaPath + File.pathSeparator; + classpath = classpath + jnaPlatformPath + File.pathSeparator; + + return classpath; + } + /** * Method that gets the path to the JAR from where a specific class is being loaded * @param clazz @@ -740,41 +706,6 @@ private static String getPathFromClass(Class clazz) throws UnsupportedEncodin return path; } - /** - * Get temporary directory to perform the interprocessing communication in MacOSX - * intel and Windows - * @return the tmp dir - * @throws IOException if the files cannot be written in any of the temp dirs - */ - private static String getTemporaryDir() throws IOException { - String tmpDir; - String enginesDir = getEnginesDir(); - if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) { - tmpDir = enginesDir + File.separator + "temp"; - if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) - tmpDir = enginesDir; - } else if (System.getenv("temp") != null - && Files.isWritable(Paths.get(System.getenv("temp")))) { - return System.getenv("temp"); - } else if (System.getenv("TEMP") != null - && Files.isWritable(Paths.get(System.getenv("TEMP")))) { - return System.getenv("TEMP"); - } else if (System.getenv("tmp") != null - && Files.isWritable(Paths.get(System.getenv("tmp")))) { - return System.getenv("tmp"); - } else if (System.getenv("TMP") != null - && Files.isWritable(Paths.get(System.getenv("TMP")))) { - return System.getenv("TMP"); - } else if (System.getProperty("java.io.tmpdir") != null - && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { - return System.getProperty("java.io.tmpdir"); - } else { - throw new IOException("Unable to find temporal directory with writting rights. " - + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); - } - return tmpDir; - } - /** * GEt the directory where the TF2 engine is located if a temporary dir is not found * @return directory of the engines @@ -806,64 +737,4 @@ private static String getEnginesDir() { } return new File(dir).getParent(); } - - /** - * Retrieve the file names used for interprocess communication - * @param args - * args provided to the main method - * @return a map with a list of input and output names - */ - private HashMap> getInputTensorsFileNames(String[] args) { - List inputNames = new ArrayList(); - List outputNames = new ArrayList(); - if (this.tensorFilenameMap == null) - this.tensorFilenameMap = new HashMap(); - for (int i = 2; i < args.length; i ++) { - if (args[i].endsWith(INPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - inputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) { - String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length()); - String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); - outputNames.add(onlyName); - tensorFilenameMap.put(onlyName, nameWTimestamp); - - } - } - if (inputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow2Interface.class.toString() + "' should contain at " - + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'."); - if (outputNames.size() == 0) - throw new IllegalArgumentException("The args to the main method of '" - + Tensorflow2Interface.class.toString() + "' should contain at " - + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'."); - HashMap> map = new HashMap>(); - map.put(INPUTS_MAP_KEY, inputNames); - map.put(OUTPUTS_MAP_KEY, outputNames); - return map; - } - - /** - * MEthod to obtain the String output of the process in case something goes wrong - * @param process - * the process that executed the TF2 model - * @return the String output that we would have seen on the terminal - * @throws IOException if the output of the terminal cannot be seen - */ - private static String readProcessStringOutput(Process process) throws IOException { - BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream())); - BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); - String text = ""; - String line; - while ((line = bufferedErrReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - while ((line = bufferedReader.readLine()) != null) { - text += line + System.lineSeparator(); - } - return text; - } } diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java new file mode 100644 index 0000000..64adaa7 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface2_old.java @@ -0,0 +1,869 @@ +/*- + * #%L + * This project complements the DL-model runner acting as the engine that works loading models + * and making inference with Java 0.2.0 API for Tensorflow 2. + * %% + * Copyright (C) 2022 - 2023 Institut Pasteur and BioImage.IO developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package io.bioimage.modelrunner.tensorflow.v2.api020; + +import com.google.protobuf.InvalidProtocolBufferException; + +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; +import io.bioimage.modelrunner.bioimageio.download.DownloadModel; +import io.bioimage.modelrunner.engine.DeepLearningEngineInterface; +import io.bioimage.modelrunner.engine.EngineInfo; +import io.bioimage.modelrunner.exceptions.LoadModelException; +import io.bioimage.modelrunner.exceptions.RunModelException; +import io.bioimage.modelrunner.system.PlatformDetection; +import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder; +import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder; +import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer; +import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2; +import io.bioimage.modelrunner.utils.Constants; +import io.bioimage.modelrunner.utils.ZipUtils; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.RandomAccessFile; +import java.io.UnsupportedEncodingException; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLDecoder; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.ProtectionDomain; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; + +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.proto.framework.MetaGraphDef; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; + +/** + * Class to that communicates with the dl-model runner, see + * @see dlmodelrunner + * to execute Tensorflow 2 models. This class is compatible with TF2 Java API 0.2.0. + * This class implements the interface {@link DeepLearningEngineInterface} to get the + * agnostic {@link io.bioimage.modelrunner.tensor.Tensor}, convert them into + * {@link org.tensorflow.Tensor}, execute a Tensorflow 2 Deep Learning model on them and + * convert the results back to {@link io.bioimage.modelrunner.tensor.Tensor} to send them + * to the main program in an agnostic manner. + * + * {@link ImgLib2Builder}. Creates ImgLib2 images for the backend + * of {@link io.bioimage.modelrunner.tensor.Tensor} from {@link org.tensorflow.Tensor} + * {@link TensorBuilder}. Converts {@link io.bioimage.modelrunner.tensor.Tensor} into {@link org.tensorflow.Tensor} + * + * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + */ +public class Tensorflow2Interface2_old implements DeepLearningEngineInterface { + + private static final String[] MODEL_TAGS = { "serve", "inference", "train", + "eval", "gpu", "tpu" }; + + private static final String[] TF_MODEL_TAGS = { + "tf.saved_model.tag_constants.SERVING", + "tf.saved_model.tag_constants.INFERENCE", + "tf.saved_model.tag_constants.TRAINING", + "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", + "tf.saved_model.tag_constants.TPU" }; + + private static final String[] SIGNATURE_CONSTANTS = { "serving_default", + "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs", + "tensorflow/serving/predict", "outputs", "inputs", + "tensorflow/serving/regress", "outputs", "train", "eval", + "tensorflow/supervised/training", "tensorflow/supervised/eval" }; + + private static final String[] TF_SIGNATURE_CONSTANTS = { + "tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.CLASSIFY_INPUTS", + "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", + "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", + "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", + "tf.saved_model.signature_constants.PREDICT_INPUTS", + "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", + "tf.saved_model.signature_constants.PREDICT_OUTPUTS", + "tf.saved_model.signature_constants.REGRESS_INPUTS", + "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", + "tf.saved_model.signature_constants.REGRESS_OUTPUTS", + "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", + "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", + "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME" }; + + /** + * Idetifier for the files that contain the data of the inputs + */ + final private static String INPUT_FILE_TERMINATION = "_model_input"; + + /** + * Idetifier for the files that contain the data of the outputs + */ + final private static String OUTPUT_FILE_TERMINATION = "_model_output"; + /** + * Key for the inputs in the map that retrieves the file names for interprocess communication + */ + final private static String INPUTS_MAP_KEY = "inputs"; + /** + * Key for the outputs in the map that retrieves the file names for interprocess communication + */ + final private static String OUTPUTS_MAP_KEY = "outputs"; + /** + * File extension for the temporal files used for interprocessing + */ + final private static String FILE_EXTENSION = ".dat"; + /** + * Name without vesion of the JAR created for this library + */ + private static final String JAR_FILE_NAME = "dl-modelrunner-tensorflow-"; + + /** + * The loaded Tensorflow 2 model + */ + private static SavedModelBundle model; + /** + * Internal object of the Tensorflow model + */ + private static SignatureDef sig; + /** + * Whether the execution needs interprocessing (MacOS Interl) or not + */ + private boolean interprocessing = false; + /** + * TEmporary dir where to store temporary files + */ + private String tmpDir; + /** + * Folde containing the model that is being executed + */ + private String modelFolder; + /** + * List of temporary files used for interprocessing communication + */ + private List listTempFiles; + /** + * HashMap that maps tensor to the temporal file name for interprocessing + */ + private HashMap tensorFilenameMap; + /** + * Process where the model is being loaded and executed + */ + Process process; + + /** + * TODO the interprocessing is executed for every OS + * Constructor that detects whether the operating system where it is being + * executed is Windows or Mac or not to know if it is going to need interprocessing + * or not + * @throws IOException if the temporary dir is not found + */ + public Tensorflow2Interface2_old() throws IOException + { + boolean isWin = PlatformDetection.isWindows(); + boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); + if (true || (isWin && isIntel)) { + interprocessing = true; + tmpDir = getTemporaryDir(); + } + } + + /** + * Private constructor that can only be launched from the class to create a separate + * process to avoid the conflicts that occur in the same process between TF2 and TF1/Pytorch + * in Windows and Mac + * @param doInterprocessing + * whether to do interprocessing or not + * @throws IOException if the temp dir is not found + */ + private Tensorflow2Interface2_old(boolean doInterprocessing) throws IOException + { + if (!doInterprocessing) { + interprocessing = false; + } else { + boolean isWin = PlatformDetection.isMacOS(); + boolean isIntel = PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64); + if (isWin && isIntel) { + interprocessing = true; + tmpDir = getTemporaryDir(); + } + } + } + + /** + * {@inheritDoc} + * + * Load a Tensorflow 2 model. If the machine where the code is + * being executed is a MacOS Intel or Windows, the model will be loaded in + * a separate process each time the method {@link #run(List, List)} + * is called + */ + @Override + public void loadModel(String modelFolder, String modelSource) + throws LoadModelException + { + this.modelFolder = modelFolder; + if (interprocessing) + return; + try { + checkModelUnzipped(); + } catch (Exception e) { + throw new LoadModelException(e.toString()); + } + model = SavedModelBundle.load(this.modelFolder, "serve"); + byte[] byteGraph = model.metaGraphDef().toByteArray(); + try { + sig = MetaGraphDef.parseFrom(byteGraph).getSignatureDefOrThrow( + "serving_default"); + } + catch (InvalidProtocolBufferException e) { + System.out.println("Invalid graph"); + } + } + + /** + * Check if an unzipped tensorflow model exists in the model folder, + * and if not look for it and unzip it + * @throws LoadModelException if no model is found + * @throws IOException if there is any error unzipping the model + * @throws Exception if there is any error related to model packaging + */ + private void checkModelUnzipped() throws LoadModelException, IOException, Exception { + if (new File(modelFolder, "variables").isDirectory() + && new File(modelFolder, "saved_model.pb").isFile()) + return; + unzipTfWeights(ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME)); + } + + /** + * Method that unzips the tensorflow model zip into the variables + * folder and .pb file, if they are saved in a zip + * @throws LoadModelException if not zip file is found + * @throws IOException if there is any error unzipping + */ + private void unzipTfWeights(ModelDescriptor descriptor) throws LoadModelException, IOException { + if (new File(modelFolder, "tf_weights.zip").isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(modelFolder + File.separator + "tf_weights.zip", modelFolder); + } else if ( descriptor.getWeights().getAllSuportedWeightNames() + .contains(EngineInfo.getBioimageioTfKey()) ) { + String source = descriptor.getWeights().gettAllSupportedWeightObjects().stream() + .filter(ww -> ww.getFramework().equals(EngineInfo.getBioimageioTfKey())) + .findFirst().get().getSource(); + if (new File(source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(source).getAbsolutePath(), modelFolder); + } else if (new File(modelFolder, source).isFile()) { + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(new File(modelFolder, source).getAbsolutePath(), modelFolder); + } else { + source = DownloadModel.getFileNameFromURLString(source); + System.out.println("Unzipping model..."); + ZipUtils.unzipFolder(modelFolder + File.separator + source, modelFolder); + } + } else { + throw new LoadModelException("No model file was found in the model folder"); + } + } + + /** + * {@inheritDoc} + * + * Run a Tensorflow2 model on the data provided by the {@link Tensor} input list + * and modifies the output list with the results obtained + */ + @Override + public void run(List> inputTensors, List> outputTensors) + throws RunModelException + { + if (interprocessing) { + runInterprocessing(inputTensors, outputTensors); + return; + } + Session session = model.session(); + Session.Runner runner = session.runner(); + List inputListNames = new ArrayList(); + List> inTensors = + new ArrayList>(); + int c = 0; + for (Tensor tt : inputTensors) { + inputListNames.add(tt.getName()); + org.tensorflow.Tensor inT = TensorBuilder.build(tt); + inTensors.add(inT); + String inputName = getModelInputName(tt.getName(), c ++); + runner.feed(inputName, inT); + } + c = 0; + for (Tensor tt : outputTensors) + runner = runner.fetch(getModelOutputName(tt.getName(), c ++)); + // Run runner + List> resultPatchTensors = runner.run(); + + // Fill the agnostic output tensors list with data from the inference result + fillOutputTensors(resultPatchTensors, outputTensors); + // Close the remaining resources + session.close(); + for (org.tensorflow.Tensor tt : inTensors) { + tt.close(); + } + for (org.tensorflow.Tensor tt : resultPatchTensors) { + tt.close(); + } + } + + /** + * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements + * to create another process, communicate the model info and tensors to the other + * process and then retrieve the results of the other process + * @param inputTensors + * tensors that are going to be run on the model + * @param outputTensors + * expected results of the model + * @throws RunModelException if there is any issue running the model + */ + public void runInterprocessing(List> inputTensors, List> outputTensors) throws RunModelException { + createTensorsForInterprocessing(inputTensors); + createTensorsForInterprocessing(outputTensors); + try { + List args = getProcessCommandsWithoutArgs(); + for (Tensor tensor : inputTensors) {args.add(getFilename4Tensor(tensor.getName()) + INPUT_FILE_TERMINATION);} + for (Tensor tensor : outputTensors) {args.add(getFilename4Tensor(tensor.getName()) + OUTPUT_FILE_TERMINATION);} + ProcessBuilder builder = new ProcessBuilder(args); + builder.redirectOutput(ProcessBuilder.Redirect.INHERIT); + builder.redirectError(ProcessBuilder.Redirect.INHERIT); + process = builder.start(); + int result = process.waitFor(); + process.destroy(); + if (result != 0) + throw new RunModelException("Error executing the Tensorflow 2 model in" + + " a separate process. The process was not terminated correctly." + + System.lineSeparator() + readProcessStringOutput(process)); + } catch (RunModelException e) { + closeModel(); + throw e; + } catch (Exception e) { + closeModel(); + throw new RunModelException(e.getCause().toString()); + } + + retrieveInterprocessingTensors(outputTensors); + } + + /** + * Create the list a list of output tensors agnostic to the Deep Learning + * engine that can be readable by the dl-modelrunner + * + * @param outputTfTensors an List containing dl-modelrunner tensors + * @param outputTensors the names given to the tensors by the model + * @throws RunModelException If the number of tensors expected is not the same + * as the number of Tensors outputed by the model + */ + public static void fillOutputTensors( + List> outputTfTensors, + List> outputTensors) throws RunModelException + { + if (outputTfTensors.size() != outputTensors.size()) + throw new RunModelException(outputTfTensors.size(), outputTensors.size()); + for (int i = 0; i < outputTfTensors.size(); i++) { + outputTensors.get(i).setData(ImgLib2Builder.build(outputTfTensors.get(i))); + } + } + + /** + * {@inheritDoc} + * + * Close the Tensorflow 2 {@link #model} and {@link #sig}. For + * MacOS Intel and Windows systems it also deletes the temporary files created to + * communicate with the other process + */ + @Override + public void closeModel() { + sig = null; + if (model != null) { + model.session().close(); + model.close(); + } + model = null; + if (listTempFiles == null) + return; + for (File ff : listTempFiles) { + if (ff.exists()) + ff.delete(); + } + listTempFiles = null; + if (process != null) + process.destroyForcibly(); + } + + // TODO make only one + /** + * Retrieves the readable input name from the graph signature definition given + * the signature input name. + * + * @param inputName Signature input name. + * @param i position of the input of interest in the list of inputs + * @return The readable input name. + */ + public static String getModelInputName(String inputName, int i) { + TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null); + if (inputInfo == null) { + inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i); + } + if (inputInfo != null) { + String modelInputName = inputInfo.getName(); + if (modelInputName != null) { + if (modelInputName.endsWith(":0")) { + return modelInputName.substring(0, modelInputName.length() - 2); + } + else { + return modelInputName; + } + } + else { + return inputName; + } + } + return inputName; + } + + /** + * Retrieves the readable output name from the graph signature definition + * given the signature output name. + * + * @param outputName Signature output name. + * @param i position of the input of interest in the list of inputs + * @return The readable output name. + */ + public static String getModelOutputName(String outputName, int i) { + TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null); + if (outputInfo == null) { + outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i); + } + if (outputInfo != null) { + String modelOutputName = outputInfo.getName(); + if (modelOutputName.endsWith(":0")) { + return modelOutputName.substring(0, modelOutputName.length() - 2); + } + else { + return modelOutputName; + } + } + else { + return outputName; + } + } + + + /** + * Methods to run interprocessing and bypass the errors that occur in MacOS intel + * with the compatibility between TF2 and TF1/Pytorch + * This method checks that the arguments are correct, retrieves the input and output + * tensors, loads the model, makes inference with it and finally sends the tensors + * to the original process + * + * @param args + * arguments of the program: + * - Path to the model folder + * - Path to a temporary dir + * - Name of the input 0 + * - Name of the input 1 + * - ... + * - Name of the output n + * - Name of the output 0 + * - Name of the output 1 + * - ... + * - Name of the output n + * @throws LoadModelException if there is any error loading the model + * @throws IOException if there is any error reading or writing any file or with the paths + * @throws RunModelException if there is any error running the model + */ + public static void main(String[] args) throws LoadModelException, IOException, RunModelException { + Tensorflow2Interface2_old tt = new Tensorflow2Interface2_old(false); + + tt.loadModel("/home/carlos/Desktop/Fiji.app/models/model_03bioimageio", null); + // Unpack the args needed + if (args.length < 4) + throw new IllegalArgumentException("Error exectuting Tensorflow 2, " + + "at least 5 arguments are required:" + System.lineSeparator() + + " - Folder where the model is located" + System.lineSeparator() + + " - Temporary dir where the memory mapped files are located" + System.lineSeparator() + + " - Name of the model input followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the second model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model input (if it exists) followed by the String + '_model_input'" + System.lineSeparator() + + " - Name of the model output followed by the String + '_model_output'" + System.lineSeparator() + + " - Name of the second model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + + " - ...." + System.lineSeparator() + + " - Name of the nth model output (if it exists) followed by the String + '_model_output'" + System.lineSeparator() + ); + String modelFolder = args[0]; + if (!(new File(modelFolder).isDirectory())) { + throw new IllegalArgumentException("Argument 0 of the main method, '" + modelFolder + "' " + + "should be an existing directory containing a Tensorflow 2 model."); + } + + Tensorflow2Interface2_old tfInterface = new Tensorflow2Interface2_old(false); + tfInterface.tmpDir = args[1]; + if (!(new File(args[1]).isDirectory())) { + throw new IllegalArgumentException("Argument 1 of the main method, '" + args[1] + "' " + + "should be an existing directory."); + } + + tfInterface.loadModel(modelFolder, modelFolder); + + HashMap> map = tfInterface.getInputTensorsFileNames(args); + List inputNames = map.get(INPUTS_MAP_KEY); + List> inputList = inputNames.stream().map(n -> { + try { + return tfInterface.retrieveInterprocessingTensorsByName(n); + } catch (RunModelException e) { + return null; + } + }).collect(Collectors.toList()); + List outputNames = map.get(OUTPUTS_MAP_KEY); + List> outputList = outputNames.stream().map(n -> { + try { + return tfInterface.retrieveInterprocessingTensorsByName(n); + } catch (RunModelException e) { + return null; + } + }).collect(Collectors.toList()); + tfInterface.run(inputList, outputList); + tfInterface.createTensorsForInterprocessing(outputList); + } + + /** + * Get the name of the temporary file associated to the tensor name + * @param name + * name of the tensor + * @return file name associated to the tensor + */ + private String getFilename4Tensor(String name) { + if (tensorFilenameMap == null) + tensorFilenameMap = new HashMap(); + if (tensorFilenameMap.get(name) != null) + return tensorFilenameMap.get(name); + LocalDateTime now = LocalDateTime.now(); + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmssSSS"); + String newName = name + "_" + now.format(formatter); + tensorFilenameMap.put(name, newName); + return tensorFilenameMap.get(name); + } + + /** + * Create a temporary file for each of the tensors in the list to communicate with + * the separate process in MacOS Intel and Windows systems + * @param tensors + * list of tensors to be sent + * @throws RunModelException if there is any error converting the tensors + */ + private void createTensorsForInterprocessing(List> tensors) throws RunModelException{ + if (this.listTempFiles == null) + this.listTempFiles = new ArrayList(); + for (Tensor tensor : tensors) { + long lenFile = ImgLib2ToMappedBuffer.findTotalLengthFile(tensor); + File ff = new File(tmpDir + File.separator + getFilename4Tensor(tensor.getName()) + FILE_EXTENSION); + if (!ff.exists()) { + ff.deleteOnExit(); + this.listTempFiles.add(ff); + } + try (RandomAccessFile rd = + new RandomAccessFile(ff, "rw"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_WRITE, 0, lenFile); + ByteBuffer byteBuffer = mem.duplicate(); + ImgLib2ToMappedBuffer.build(tensor, byteBuffer); + } catch (IOException e) { + closeModel(); + throw new RunModelException(e.getCause().toString()); + } + } + } + + /** + * Retrieves the data of the tensors contained in the input list from the output + * generated by the independent process + * @param tensors + * list of tensors that are going to be filled + * @throws RunModelException if there is any issue retrieving the data from the other process + */ + private void retrieveInterprocessingTensors(List> tensors) throws RunModelException{ + for (Tensor tensor : tensors) { + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + + this.getFilename4Tensor(tensor.getName()) + FILE_EXTENSION, "r"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); + ByteBuffer byteBuffer = mem.duplicate(); + tensor.setData(MappedBufferToImgLib2.build(byteBuffer)); + } catch (IOException e) { + closeModel(); + throw new RunModelException(e.getCause().toString()); + } + } + } + + /** + * Create a tensor from the data contained in a file named as the parameter + * provided as an input + the file extension {@link #FILE_EXTENSION}. + * This file is produced by another process to communicate with the current process + * @param + * generic type of the tensor + * @param name + * name of the file without the extension ({@link #FILE_EXTENSION}). + * @return a tensor created with the data in the file + * @throws RunModelException if there is any problem retrieving the data and cerating the tensor + */ + private < T extends RealType< T > & NativeType< T > > Tensor + retrieveInterprocessingTensorsByName(String name) throws RunModelException { + try (RandomAccessFile rd = + new RandomAccessFile(tmpDir + File.separator + + this.getFilename4Tensor(name) + FILE_EXTENSION, "r"); + FileChannel fc = rd.getChannel();) { + MappedByteBuffer mem = fc.map(FileChannel.MapMode.READ_ONLY, 0, fc.size()); + ByteBuffer byteBuffer = mem.duplicate(); + return MappedBufferToImgLib2.buildTensor(byteBuffer); + } catch (IOException e) { + closeModel(); + throw new RunModelException(e.getCause().toString()); + } + } + + /** + * if java bin dir contains any special char, surround it by double quotes + * @param javaBin + * java bin dir + * @return impored java bin dir if needed + */ + private static String padSpecialJavaBin(String javaBin) { + String[] specialChars = new String[] {" "}; + for (String schar : specialChars) { + if (javaBin.contains(schar) && PlatformDetection.isWindows()) { + return "\"" + javaBin + "\""; + } + } + return javaBin; + } + + /** + * Create the arguments needed to execute tensorflow 2 in another + * process with the corresponding tensors + * @return the command used to call the separate process + * @throws IOException if the command needed to execute interprocessing is too long + * @throws URISyntaxException if there is any error with the URIs retrieved from the classes + */ + private List getProcessCommandsWithoutArgs() throws IOException, URISyntaxException { + String javaHome = System.getProperty("java.home"); + String javaBin = javaHome + File.separator + "bin" + File.separator + "java"; + + String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class); + String imglib2Path = getPathFromClass(NativeType.class); + if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class") + && !modelrunnerPath.contains(File.pathSeparator))) + modelrunnerPath = System.getProperty("java.class.path"); + String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator; + ProtectionDomain protectionDomain = Tensorflow2Interface2_old.class.getProtectionDomain(); + String codeSource = protectionDomain.getCodeSource().getLocation().getPath(); + String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString()); + f_name = new File(f_name).getAbsolutePath(); + for (File ff : new File(f_name).getParentFile().listFiles()) { + if (ff.getName().startsWith(JAR_FILE_NAME) && !ff.getAbsolutePath().equals(f_name)) + continue; + classpath += ff.getAbsolutePath() + File.pathSeparator; + } + String className = Tensorflow2Interface2_old.class.getName(); + List command = new LinkedList(); + command.add(padSpecialJavaBin(javaBin)); + command.add("-cp"); + command.add(classpath); + command.add(className); + command.add(modelFolder); + command.add(this.tmpDir); + return command; + } + + /** + * Method that gets the path to the JAR from where a specific class is being loaded + * @param clazz + * class of interest + * @return the path to the JAR that contains the class + * @throws UnsupportedEncodingException if the url of the JAR is not encoded in UTF-8 + */ + private static String getPathFromClass(Class clazz) throws UnsupportedEncodingException { + String classResource = clazz.getName().replace('.', '/') + ".class"; + URL resourceUrl = clazz.getClassLoader().getResource(classResource); + if (resourceUrl == null) { + return null; + } + String urlString = resourceUrl.toString(); + if (urlString.startsWith("jar:")) { + urlString = urlString.substring(4); + } + if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) { + urlString = urlString.substring(6); + } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) { + urlString = urlString.substring(5); + } + urlString = URLDecoder.decode(urlString, "UTF-8"); + File file = new File(urlString); + String path = file.getAbsolutePath(); + if (path.lastIndexOf(".jar!") != -1) + path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; + return path; + } + + /** + * Get temporary directory to perform the interprocessing communication in MacOSX + * intel and Windows + * @return the tmp dir + * @throws IOException if the files cannot be written in any of the temp dirs + */ + private static String getTemporaryDir() throws IOException { + String tmpDir; + String enginesDir = getEnginesDir(); + if (enginesDir != null && Files.isWritable(Paths.get(enginesDir))) { + tmpDir = enginesDir + File.separator + "temp"; + if (!(new File(tmpDir).isDirectory()) && !(new File(tmpDir).mkdirs())) + tmpDir = enginesDir; + } else if (System.getenv("temp") != null + && Files.isWritable(Paths.get(System.getenv("temp")))) { + return System.getenv("temp"); + } else if (System.getenv("TEMP") != null + && Files.isWritable(Paths.get(System.getenv("TEMP")))) { + return System.getenv("TEMP"); + } else if (System.getenv("tmp") != null + && Files.isWritable(Paths.get(System.getenv("tmp")))) { + return System.getenv("tmp"); + } else if (System.getenv("TMP") != null + && Files.isWritable(Paths.get(System.getenv("TMP")))) { + return System.getenv("TMP"); + } else if (System.getProperty("java.io.tmpdir") != null + && Files.isWritable(Paths.get(System.getProperty("java.io.tmpdir")))) { + return System.getProperty("java.io.tmpdir"); + } else { + throw new IOException("Unable to find temporal directory with writting rights. " + + "Please either allow writting on the system temporal folder or on '" + enginesDir + "'."); + } + return tmpDir; + } + + /** + * GEt the directory where the TF2 engine is located if a temporary dir is not found + * @return directory of the engines + */ + private static String getEnginesDir() { + String dir; + try { + dir = getPathFromClass(Tensorflow2Interface2_old.class); + } catch (UnsupportedEncodingException e) { + String classResource = Tensorflow2Interface2_old.class.getName().replace('.', '/') + ".class"; + URL resourceUrl = Tensorflow2Interface2_old.class.getClassLoader().getResource(classResource); + if (resourceUrl == null) { + return null; + } + String urlString = resourceUrl.toString(); + if (urlString.startsWith("jar:")) { + urlString = urlString.substring(4); + } + if (urlString.startsWith("file:/") && PlatformDetection.isWindows()) { + urlString = urlString.substring(6); + } else if (urlString.startsWith("file:/") && !PlatformDetection.isWindows()) { + urlString = urlString.substring(5); + } + File file = new File(urlString); + String path = file.getAbsolutePath(); + if (path.lastIndexOf(".jar!") != -1) + path = path.substring(0, path.lastIndexOf(".jar!")) + ".jar"; + dir = path; + } + return new File(dir).getParent(); + } + + /** + * Retrieve the file names used for interprocess communication + * @param args + * args provided to the main method + * @return a map with a list of input and output names + */ + private HashMap> getInputTensorsFileNames(String[] args) { + List inputNames = new ArrayList(); + List outputNames = new ArrayList(); + if (this.tensorFilenameMap == null) + this.tensorFilenameMap = new HashMap(); + for (int i = 2; i < args.length; i ++) { + if (args[i].endsWith(INPUT_FILE_TERMINATION)) { + String nameWTimestamp = args[i].substring(0, args[i].length() - INPUT_FILE_TERMINATION.length()); + String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); + inputNames.add(onlyName); + tensorFilenameMap.put(onlyName, nameWTimestamp); + } else if (args[i].endsWith(OUTPUT_FILE_TERMINATION)) { + String nameWTimestamp = args[i].substring(0, args[i].length() - OUTPUT_FILE_TERMINATION.length()); + String onlyName = nameWTimestamp.substring(0, nameWTimestamp.lastIndexOf("_")); + outputNames.add(onlyName); + tensorFilenameMap.put(onlyName, nameWTimestamp); + + } + } + if (inputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow2Interface2_old.class.toString() + "' should contain at " + + "least one input, defined as ' + '" + INPUT_FILE_TERMINATION + "'."); + if (outputNames.size() == 0) + throw new IllegalArgumentException("The args to the main method of '" + + Tensorflow2Interface2_old.class.toString() + "' should contain at " + + "least one output, defined as ' + '" + OUTPUT_FILE_TERMINATION + "'."); + HashMap> map = new HashMap>(); + map.put(INPUTS_MAP_KEY, inputNames); + map.put(OUTPUTS_MAP_KEY, outputNames); + return map; + } + + /** + * MEthod to obtain the String output of the process in case something goes wrong + * @param process + * the process that executed the TF2 model + * @return the String output that we would have seen on the terminal + * @throws IOException if the output of the terminal cannot be seen + */ + private static String readProcessStringOutput(Process process) throws IOException { + BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader bufferedErrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String text = ""; + String line; + while ((line = bufferedErrReader.readLine()) != null) { + text += line + System.lineSeparator(); + } + while ((line = bufferedReader.readLine()) != null) { + text += line + System.lineSeparator(); + } + return text; + } +}