Skip to content

Commit

Permalink
convert to long arrays and start managing input tile ceration by ref
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2023
1 parent 00ef328 commit f080bb1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
22 changes: 20 additions & 2 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import javax.xml.bind.ValidationException;

Expand Down Expand Up @@ -61,6 +62,9 @@
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
* Class that manages a Deep Learning model to load it and run it.
Expand Down Expand Up @@ -545,8 +549,22 @@ void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, Patch
int nTiles = 1;
for (int i : tilesPerAxis) nTiles *= i;

for (int i = 0; i < nTiles; i ++) {

for (int j = 0; j < nTiles; j ++) {
int tileCount = j + 0;
IntStream.range(0, inputTensors.size()).mapToObj(i -> {
if (!inputTensors.get(i).isImage())
return inputTensors.get(i);
RandomAccessibleInterval<R> tileRai = Views.interval(
Views.extendBorder(inputTensors.get(i).getData()),
inTileGrids.get(inputTensors.get(i).getName()).getTilePostionsInImage().get(tileCount),
(long[]) inTileGrids.get(inputTensors.get(i).getName()).getTileSize());
/*
RandomAccessibleInterval<R> tileRai = Views.interval(
Views.extendBorder(inputTensors.get(i).getData()),
Intervals.expand(inputTensors.get(i).getData(), 50));
*/
return Tensor.build(inputTensors.get(i).getName(), inputTensors.get(i).getAxesOrderString(), tileRai);
});
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -281,7 +282,8 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>
+ " rdf.yaml file for tensor '" + spec.getName() + "': " + e.getMessage());
}
}
return computePatchSpecs(spec, rai, spec.getProcessingPatch());
long[] tileSize = Arrays.stream(spec.getProcessingPatch()).mapToLong(i -> i).toArray();
return computePatchSpecs(spec, rai, tileSize);
}

/**
Expand All @@ -296,7 +298,7 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>
*
* @return an object containing the specs needed to perform patching for the particular tensor
*/
private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T> rai, int[] tileSize)
private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T> rai, long[] tileSize)
{
int[][] paddingSize = new int[2][tileSize.length];
// REgard that the input halo represents the output halo + offset
Expand Down Expand Up @@ -337,17 +339,17 @@ private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchS
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
int[][] paddingSize = refTilesSpec.getPatchPaddingSize();
int[] tileSize;
long[] tileSize;
long[] shapeLong;
if (tensorSpec.getShape().getReferenceInput() == null) {
tileSize = tensorSpec.getShape().getPatchRecomendedSize();
tileSize = Arrays.stream(tensorSpec.getShape().getPatchRecomendedSize()).mapToLong(i -> i).toArray();
shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (tileSize[(int) i] - paddingSize[0][(int) i] - paddingSize[0][(int) i]) * inputTileGrid[(int) i])
.toArray();
} else {
tileSize = IntStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (int) (refTilesSpec.getPatchInputSize()[i] * tensorSpec.getShape().getScale()[i] + 2 * tensorSpec.getShape().getOffset()[i]))
.toArray();
.mapToLong(i -> i).toArray();
shapeLong = LongStream.range(0, tensorSpec.getAxesOrder().length())
.map(i -> (int) (refTilesSpec.getTensorDims()[(int) i] * tensorSpec.getShape().getScale()[(int) i]
+ 2 * tensorSpec.getShape().getOffset()[(int) i])).toArray();
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/io/bioimage/modelrunner/tiling/PatchSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class PatchSpec
/**
* Size of the input patch. Following "xyczb" axes order
*/
private int[] patchInputSize;
private long[] patchInputSize;
/**
* Size of the number of patches per axis. Following "xyczb" axes order
*/
Expand Down Expand Up @@ -67,7 +67,7 @@ public class PatchSpec
* The padding size used on each patch.
* @return The create patch specification.
*/
public static PatchSpec create(String tensorName, int[] patchInputSize, int[] patchGridSize,
public static PatchSpec create(String tensorName, long[] patchInputSize, int[] patchGridSize,
int[][] patchPaddingSize, long[] tensorDims)
{
PatchSpec ps = new PatchSpec();
Expand Down Expand Up @@ -170,7 +170,7 @@ public long[] getTensorDims() {
/**
* @return Input patch size. The patch taken from the input sequence including the halo.
*/
public int[] getPatchInputSize()
public long[] getPatchInputSize()
{
return patchInputSize;
}
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class TileGrid
/**
* Size of the input patch. Following the tensor axes order
*/
private int[] tileSize;
private long[] tileSize;
/**
* Size of roi of each tile, following the tensor axes order
*/
Expand Down Expand Up @@ -73,10 +73,10 @@ public static TileGrid create(PatchSpec tileSpecs)

for (int j = 0; j < tileCount; j ++) {
int[] patchIndex = IndexingUtils.flatIntoMultidimensionalIndex(j, gridSize);
int[] patchSize = tileSpecs.getPatchInputSize();
long[] patchSize = tileSpecs.getPatchInputSize();
int[][] padSize = tileSpecs.getPatchPaddingSize();
int[] roiSize = IntStream.range(0, patchIndex.length)
.map(i -> patchSize[i] - padSize[0][i] - padSize[1][i]).toArray();
.map(i -> (int) patchSize[i] - padSize[0][i] - padSize[1][i]).toArray();
ps.roiSize = roiSize;
ps.roiPositionsInTile.add(IntStream.range(0, padSize[0].length).mapToLong(i -> (long) padSize[0][i]).toArray());
long[] roiStart = LongStream.range(0, patchIndex.length)
Expand All @@ -94,7 +94,7 @@ public String getTensorName() {
return tensorName;
}

public int[] getTileSize() {
public long[] getTileSize() {
return this.tileSize;
}

Expand Down

0 comments on commit f080bb1

Please sign in to comment.