From 6d617fd4777e1a44c2fcc16bf727ccd267c84792 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 26 Oct 2023 13:26:28 +0200 Subject: [PATCH] add model loaded check --- .../io/bioimage/modelrunner/model/Model.java | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/model/Model.java b/src/main/java/io/bioimage/modelrunner/model/Model.java index 151b13a7..037e9434 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Model.java +++ b/src/main/java/io/bioimage/modelrunner/model/Model.java @@ -37,6 +37,8 @@ import javax.xml.bind.ValidationException; +import ij.IJ; +import ij.ImagePlus; import io.bioimage.modelrunner.bioimageio.bioengine.BioEngineAvailableModels; import io.bioimage.modelrunner.bioimageio.bioengine.BioengineInterface; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; @@ -59,6 +61,7 @@ import net.imglib2.img.array.ArrayImg; import net.imglib2.img.array.ArrayImgs; import net.imglib2.img.basictypeaccess.array.FloatArray; +import net.imglib2.img.display.imagej.ImageJFunctions; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; @@ -73,6 +76,10 @@ */ public class Model { + /** + * Whether the model is loaded or not + */ + boolean loaded = false; /** * ClassLoader containing all the classes needed to use the corresponding * Deep Learning framework (engine). @@ -458,6 +465,7 @@ public void loadModel() throws LoadModelException if (engineClassLoader.isBioengine()) ((BioengineInterface) engineInstance).addServer(engineInfo.getServer()); engineClassLoader.setBaseClassLoader(); + loaded = true; } /** @@ -473,6 +481,7 @@ public void closeModel() engineInstance = null; engineClassLoader.setBaseClassLoader(); engineClassLoader = null; + loaded = false; } /** @@ -509,6 +518,8 @@ public void runModel( List< Tensor < ? > > inTensors, List< Tensor < ? > > outTe */ public & NativeType, R extends RealType & NativeType> List> runBioimageioModelOnImgLib2WithTiling(List> inputTensors) throws ValidationException, RunModelException { + if (!this.isLoaded()) + throw new RunModelException("Please first load the model."); if (descriptor == null && modelFolder == null) throw new IllegalArgumentException(""); else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile())) @@ -516,13 +527,12 @@ else if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFi else if (descriptor == null) descriptor = ModelDescriptor.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME); PatchGridCalculator tileGrid = PatchGridCalculator.build(descriptor, inputTensors); - runTiling(inputTensors, tileGrid); - return null; + return runTiling(inputTensors, tileGrid); } @SuppressWarnings("unchecked") private & NativeType, R extends RealType & NativeType> - void runTiling(List> inputTensors, PatchGridCalculator tileGrid) throws RunModelException { + List> runTiling(List> inputTensors, PatchGridCalculator tileGrid) throws RunModelException { LinkedHashMap inTileSpecs = tileGrid.getInputTensorsTileSpecs(); LinkedHashMap outTileSpecs = tileGrid.getOutputTensorsTileSpecs(); List> outputTensors = new ArrayList>(); @@ -536,6 +546,7 @@ void runTiling(List> inputTensors, PatchGridCalculator tileGrid) th new FloatType())); } doTiling(inputTensors, outputTensors, tileGrid); + return outputTensors; } private & NativeType, R extends RealType & NativeType> @@ -579,19 +590,24 @@ void doTiling(List> inputTensors, List> outputTensors, Patch this.runModel(inputTileList, outputTileList); } - } - public static & RealType> void main(String[] args) throws IOException { + public static & RealType> void main(String[] args) throws IOException, ValidationException, LoadEngineException, RunModelException { String mm = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\StarDist H&E Nuclei Segmentation_06092023_020924\\"; Img im = ArrayImgs.floats(new long[] {1, 511, 512, 3}); + ImagePlus imp = IJ.openImage(mm + File.separator + "sample_input_0.tif"); + imp.show(); + RandomAccessibleInterval wrapImg = ImageJFunctions.convertFloat(imp); + wrapImg = (RandomAccessibleInterval) Views.addDimension(wrapImg, 0, 0); + wrapImg = (RandomAccessibleInterval) Views.permute(wrapImg, 2, 3); + wrapImg = (RandomAccessibleInterval) Views.permute(wrapImg, 1, 2); + wrapImg = (RandomAccessibleInterval) Views.permute(wrapImg, 0, 1); List> l = new ArrayList>(); - l.add((Tensor) Tensor.build("input", "bxyc", im)); - PatchGridCalculator tileGrid = PatchGridCalculator.build(mm, l); - LinkedHashMap inTileSpecs = tileGrid.getInputTensorsTileSpecs(); - LinkedHashMap outTileSpecs = tileGrid.getOutputTensorsTileSpecs(); - TileGrid aa = TileGrid.create(inTileSpecs.get("input")); - TileGrid bb = TileGrid.create(outTileSpecs.get("output")); + l.add((Tensor) Tensor.build("input", "bxyc", wrapImg)); + Model model = createBioimageioModel(mm); + model.loadModel(); + List> out = model.runBioimageioModelOnImgLib2WithTiling(l); + ImageJFunctions.show(Views.dropSingletonDimensions(out.get(0).getData())); System.out.println(false); } @@ -652,4 +668,12 @@ public boolean isBioengine() { public EngineInfo getEngineInfo() { return engineInfo; } + + /** + * Whether the model is loaded or not + * @return + */ + public boolean isLoaded() { + return loaded; + } }