From f080bb11125e6aa1abe89ee8cd9cf27651a8ccb3 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 25 Oct 2023 20:11:14 +0200 Subject: [PATCH] convert to long arrays and start managing input tile ceration by ref --- .../io/bioimage/modelrunner/model/Model.java | 22 +++++++++++++++++-- .../tiling/PatchGridCalculator.java | 12 +++++----- .../modelrunner/tiling/PatchSpec.java | 6 ++--- .../bioimage/modelrunner/tiling/TileGrid.java | 8 +++---- 4 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/model/Model.java b/src/main/java/io/bioimage/modelrunner/model/Model.java index 758332f4..32828fd9 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Model.java +++ b/src/main/java/io/bioimage/modelrunner/model/Model.java @@ -33,6 +33,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; +import java.util.stream.IntStream; import javax.xml.bind.ValidationException; @@ -61,6 +62,9 @@ import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Intervals; +import net.imglib2.view.IntervalView; +import net.imglib2.view.Views; /** * Class that manages a Deep Learning model to load it and run it. @@ -545,8 +549,22 @@ void doTiling(List> inputTensors, List> outputTensors, Patch int nTiles = 1; for (int i : tilesPerAxis) nTiles *= i; - for (int i = 0; i < nTiles; i ++) { - + for (int j = 0; j < nTiles; j ++) { + int tileCount = j + 0; + IntStream.range(0, inputTensors.size()).mapToObj(i -> { + if (!inputTensors.get(i).isImage()) + return inputTensors.get(i); + RandomAccessibleInterval tileRai = Views.interval( + Views.extendBorder(inputTensors.get(i).getData()), + inTileGrids.get(inputTensors.get(i).getName()).getTilePostionsInImage().get(tileCount), + (long[]) inTileGrids.get(inputTensors.get(i).getName()).getTileSize()); + /* + RandomAccessibleInterval tileRai = Views.interval( + Views.extendBorder(inputTensors.get(i).getData()), + Intervals.expand(inputTensors.get(i).getData(), 50)); + */ + return Tensor.build(inputTensors.get(i).getName(), inputTensors.get(i).getAxesOrderString(), tileRai); + }); } } diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java index f6c5f7a5..3b90476c 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; +import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -281,7 +282,8 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval + " rdf.yaml file for tensor '" + spec.getName() + "': " + e.getMessage()); } } - return computePatchSpecs(spec, rai, spec.getProcessingPatch()); + long[] tileSize = Arrays.stream(spec.getProcessingPatch()).mapToLong(i -> i).toArray(); + return computePatchSpecs(spec, rai, tileSize); } /** @@ -296,7 +298,7 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval * * @return an object containing the specs needed to perform patching for the particular tensor */ - private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval rai, int[] tileSize) + private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval rai, long[] tileSize) { int[][] paddingSize = new int[2][tileSize.length]; // REgard that the input halo represents the output halo + offset @@ -337,17 +339,17 @@ private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchS // REgard that the input halo represents the output halo + offset // and must be divisible by 0.5. int[][] paddingSize = refTilesSpec.getPatchPaddingSize(); - int[] tileSize; + long[] tileSize; long[] shapeLong; if (tensorSpec.getShape().getReferenceInput() == null) { - tileSize = tensorSpec.getShape().getPatchRecomendedSize(); + tileSize = Arrays.stream(tensorSpec.getShape().getPatchRecomendedSize()).mapToLong(i -> i).toArray(); shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length()) .map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i]) .toArray(); } else { tileSize = IntStream.range(0, tensorSpec.getAxesOrder().length()) .map(i -> (int) (refTilesSpec.getPatchInputSize()[i] * tensorSpec.getShape().getScale()[i] + 2 * tensorSpec.getShape().getOffset()[i])) - .toArray(); + .mapToLong(i -> i).toArray(); shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length()) .map(i -> (int) (refTilesSpec.getTensorDims()[(int) i] * tensorSpec.getShape().getScale()[(int) i] + 2 * tensorSpec.getShape().getOffset()[(int) i])).toArray(); diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java index 1658cacd..5567350d 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java @@ -37,7 +37,7 @@ public class PatchSpec /** * Size of the input patch. Following "xyczb" axes order */ - private int[] patchInputSize; + private long[] patchInputSize; /** * Size of the number of patches per axis. Following "xyczb" axes order */ @@ -67,7 +67,7 @@ public class PatchSpec * The padding size used on each patch. * @return The create patch specification. */ - public static PatchSpec create(String tensorName, int[] patchInputSize, int[] patchGridSize, + public static PatchSpec create(String tensorName, long[] patchInputSize, int[] patchGridSize, int[][] patchPaddingSize, long[] tensorDims) { PatchSpec ps = new PatchSpec(); @@ -170,7 +170,7 @@ public long[] getTensorDims() { /** * @return Input patch size. The patch taken from the input sequence including the halo. */ - public int[] getPatchInputSize() + public long[] getPatchInputSize() { return patchInputSize; } diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java index 5ae48f5e..5cfec00f 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java @@ -38,7 +38,7 @@ public class TileGrid /** * Size of the input patch. Following the tensor axes order */ - private int[] tileSize; + private long[] tileSize; /** * Size of roi of each tile, following the tensor axes order */ @@ -73,10 +73,10 @@ public static TileGrid create(PatchSpec tileSpecs) for (int j = 0; j < tileCount; j ++) { int[] patchIndex = IndexingUtils.flatIntoMultidimensionalIndex(j, gridSize); - int[] patchSize = tileSpecs.getPatchInputSize(); + long[] patchSize = tileSpecs.getPatchInputSize(); int[][] padSize = tileSpecs.getPatchPaddingSize(); int[] roiSize = IntStream.range(0, patchIndex.length) - .map(i -> patchSize[i] - padSize[0][i] - padSize[1][i]).toArray(); + .map(i -> (int) patchSize[i] - padSize[0][i] - padSize[1][i]).toArray(); ps.roiSize = roiSize; ps.roiPositionsInTile.add(IntStream.range(0, padSize[0].length).mapToLong(i -> (long) padSize[0][i]).toArray()); long[] roiStart = LongStream.range(0, patchIndex.length) @@ -94,7 +94,7 @@ public String getTensorName() { return tensorName; } - public int[] getTileSize() { + public long[] getTileSize() { return this.tileSize; }