Skip to content

Commit

Permalink
start filling get methods
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 21, 2024
1 parent a1e2724 commit 60e8945
Showing 1 changed file with 67 additions and 16 deletions.
83 changes: 67 additions & 16 deletions src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ public class TileCalculator {
private List<TileInfo> outputTileInfo;

private final ModelDescriptor descriptor;

private final TileFactory factory;


private final LinkedHashMap<String, PatchSpec> input = new LinkedHashMap<String, PatchSpec>();

private final LinkedHashMap<String, PatchSpec> output = new LinkedHashMap<String, PatchSpec>();

private final LinkedHashMap<String, TileGrid> inputGrid = new LinkedHashMap<String, TileGrid>();

private final LinkedHashMap<String, TileGrid> outputGrid = new LinkedHashMap<String, TileGrid>();

private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.inputTileInfo = tileInfoList;
validate();
this.factory = TileFactory.init(descriptor);
calculate();
}

public static TileCalculator build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
Expand Down Expand Up @@ -271,7 +273,16 @@ private void calculate() {
for (TensorSpec tt : this.descriptor.getInputTensors()) {
TileInfo tile = inputTileInfo.stream()
.filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null);
input.put(tt.getTensorID(), computePatchSpecs(tt, tile));
PatchSpec patch = computePatchSpecs(tt, tile);
input.put(tt.getTensorID(), patch);
inputGrid.put(tt.getTensorID(), TileGrid.create(patch));
}
for (TensorSpec tt : this.descriptor.getOutputTensors()) {
TileInfo tile = outputTileInfo.stream()
.filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null);
PatchSpec patch = computePatchSpecs(tt, tile);
output.put(tt.getTensorID(), patch);
outputGrid.put(tt.getTensorID(), TileGrid.create(patch));
}
}

Expand Down Expand Up @@ -331,7 +342,17 @@ public void getTileList() {

}

public void getInsertionPoints(String tensorName, int nTile, String axesOrder) {
public void getInputInsertionPoints(String tensorId, int nTile, String axes) {
TileInfo tile = this.inputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Input tensor '" + tensorId + "' does not require tiling.");

}

public void getOutputInsertionPoints(String tensorId, int nTile, String axes) {
TileInfo tile = this.outputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Output tensor '" + tensorId + "' does not require tiling.");

}

Expand All @@ -357,16 +378,44 @@ public List<String> getOutputTensorNames() {
*
* @return size of the tile that is going to be used to process the image
*/
public long[] getTileSize(String tensorId) {
return null;
public long[] getInputTileSize(String tensorId) {
TileInfo tile = this.inputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Input tensor '" + tensorId + "' does not require tiling.");
return tile.getProposedTileDimensions();
}

/**
*
* @return size of the tile that is going to be used to process the image
*/
public long[] getOutputTileSize(String tensorId) {
TileInfo tile = this.outputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Output tensor '" + tensorId + "' does not require tiling.");
return tile.getProposedTileDimensions();
}

/**
*
* @return size of the roi of each of the tiles that is going to be used to process the image
*/
public int[] getRoiSize(String tensorId) {
return null;
public int[] getInputRoiSize(String tensorId) {
TileInfo tile = this.inputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Input tensor '" + tensorId + "' does not require tiling.");
return this.inputGrid.get(tensorId).getRoiSize();
}

/**
*
* @return size of the roi of each of the tiles that is going to be used to process the image
*/
public int[] getOutputRoiSize(String tensorId) {
TileInfo tile = this.outputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Output tensor '" + tensorId + "' does not require tiling.");
return this.outputGrid.get(tensorId).getRoiSize();
}

/**
Expand All @@ -376,7 +425,10 @@ public int[] getRoiSize(String tensorId) {
* The positions might be negative as the image that is going to be processed might have padding on the edges
*/
public List<long[]> getTilePostionsInputImage(String tensorId) {
return null;
TileInfo tile = this.inputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Input tensor '" + tensorId + "' does not require tiling.");
return inputGrid.get(tensorId).getTilePostionsInImage();
}

/**
Expand All @@ -386,12 +438,11 @@ public List<long[]> getTilePostionsInputImage(String tensorId) {
* The positions might be negative as the image that is going to be processed might have padding on the edges
*/
public List<long[]> getTilePostionsOutputImage(String tensorId) {
return null;
TileInfo tile = this.outputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("Output tensor '" + tensorId + "' does not require tiling.");
return outputGrid.get(tensorId).getTilePostionsInImage();
}

private static void checkTileDims(TensorSpec tensor, TileInfo tile) {

}

/**
* Convert the array following given axes order into
Expand Down

0 comments on commit 60e8945

Please sign in to comment.