diff --git a/src/main/java/io/bioimage/modelrunner/model/Model.java b/src/main/java/io/bioimage/modelrunner/model/Model.java index aeb4938e..2a872431 100644 --- a/src/main/java/io/bioimage/modelrunner/model/Model.java +++ b/src/main/java/io/bioimage/modelrunner/model/Model.java @@ -61,6 +61,7 @@ import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Cast; +import net.imglib2.util.Util; /** * Class that manages a Deep Learning model to load it and run it. @@ -536,7 +537,7 @@ void runModel( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors ) engineClassLoader.setEngineClassLoader(); ArrayList> inTensorsFloat = new ArrayList>(); for (Tensor tt : inTensors) { - if (tt.getData().getAt(0) instanceof FloatType) + if (Util.getTypeFromInterval(tt.getData()) instanceof FloatType) inTensorsFloat.add(Cast.unchecked(tt)); else inTensorsFloat.add(Tensor.createCopyOfTensorInWantedDataType( tt, new FloatType() )); diff --git a/src/main/java/io/bioimage/modelrunner/transformations/AbstractTensorPixelTransformation.java b/src/main/java/io/bioimage/modelrunner/transformations/AbstractTensorPixelTransformation.java index 33d90cb5..d0f449f6 100644 --- a/src/main/java/io/bioimage/modelrunner/transformations/AbstractTensorPixelTransformation.java +++ b/src/main/java/io/bioimage/modelrunner/transformations/AbstractTensorPixelTransformation.java @@ -20,20 +20,12 @@ package io.bioimage.modelrunner.transformations; import io.bioimage.modelrunner.tensor.Tensor; -import net.imglib2.RandomAccessibleInterval; import net.imglib2.loops.LoopBuilder; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.IntegerType; import net.imglib2.type.numeric.RealType; -import net.imglib2.type.numeric.integer.ByteType; -import net.imglib2.type.numeric.integer.IntType; -import net.imglib2.type.numeric.integer.LongType; -import net.imglib2.type.numeric.integer.ShortType; -import net.imglib2.type.numeric.integer.UnsignedByteType; -import net.imglib2.type.numeric.integer.UnsignedIntType; -import net.imglib2.type.numeric.integer.UnsignedShortType; -import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Util; /** * Abstract classes for tensor transformations where a new pixel value can be @@ -69,7 +61,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply( public < R extends RealType< R > & NativeType< R > > void applyInPlace( final Tensor< R > input ) { - if (input.getData().getAt(0) instanceof IntegerType && dun != null) { + if (Util.getTypeFromInterval(input.getData()) instanceof IntegerType && dun != null) { LoopBuilder .setImages( input.getData() ) .multiThreaded() diff --git a/src/main/java/io/bioimage/modelrunner/transformations/ScaleLinearTransformation.java b/src/main/java/io/bioimage/modelrunner/transformations/ScaleLinearTransformation.java index b3e5fb51..1e03aa30 100644 --- a/src/main/java/io/bioimage/modelrunner/transformations/ScaleLinearTransformation.java +++ b/src/main/java/io/bioimage/modelrunner/transformations/ScaleLinearTransformation.java @@ -233,47 +233,5 @@ void scaleLinear(RandomAccessibleInterval rai, double gain, double offset) { .multiThreaded() .forEachPixel( i -> i.setReal((i.getRealDouble() * gain + offset) ) ); } - /** - * TODO remove - if (rai.getAt(0) instanceof ByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((byte) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof UnsignedByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof ShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((short) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof UnsignedShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof IntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof UnsignedIntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof LongType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof FloatType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((float) (i.get() * gain + offset) ) ); - } else if (rai.getAt(0) instanceof DoubleType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((double) (i.get() * gain + offset) ) ); - } else { - throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); - } - */ } } diff --git a/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java b/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java index f9280703..2480916e 100644 --- a/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java +++ b/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java @@ -152,75 +152,6 @@ private < R extends RealType< R > & NativeType< R > > void globalScale( final Te cursor.next(); flatArr[count ++] = cursor.get().getRealDouble(); } - /* - * TODO remove - if (rai.getAt(0) instanceof ByteType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof UnsignedByteType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof ShortType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof UnsignedShortType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof IntType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof UnsignedIntType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof LongType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof FloatType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else if (rai.getAt(0) instanceof DoubleType) { - final Cursor cursor = (Cursor) flatImage.cursor(); - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = (double) cursor.get().get(); - } - } else { - throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); - } - */ Arrays.sort(flatArr); int percentilePos = (int) (flatSize * percentile); @@ -321,47 +252,5 @@ void scaleRange(RandomAccessibleInterval rai, double maxPercentileVal, double .multiThreaded() .forEachPixel( i -> i.setReal(((i.getRealDouble() - minPercentileVal) / (diff + eps)) ) ); } - /** - * TODO remove - if (rai.getAt(0) instanceof ByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((byte) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof ShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((short) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof IntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedIntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof LongType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else if (rai.getAt(0) instanceof FloatType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((float) ((i.get() - minPercentileVal) / (diff + eps))) ); - } else if (rai.getAt(0) instanceof DoubleType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((double) ((i.get() - minPercentileVal) / (diff + eps)) ) ); - } else { - throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); - } - */ } } diff --git a/src/main/java/io/bioimage/modelrunner/transformations/ZeroMeanUnitVarianceTransformation.java b/src/main/java/io/bioimage/modelrunner/transformations/ZeroMeanUnitVarianceTransformation.java index b65f13c9..8baa8d51 100644 --- a/src/main/java/io/bioimage/modelrunner/transformations/ZeroMeanUnitVarianceTransformation.java +++ b/src/main/java/io/bioimage/modelrunner/transformations/ZeroMeanUnitVarianceTransformation.java @@ -389,47 +389,5 @@ void zeroMeanUnitVariance(RandomAccessibleInterval rai, double mean, double s .multiThreaded() .forEachPixel( i -> i.setReal(((i.getRealDouble() - mean) / (std + eps)) ) ); } - /** - * TODO remove - if (rai.getAt(0) instanceof ByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((byte) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedByteType) { - LoopBuilder.setImages( (RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof ShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((short) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedShortType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof IntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof UnsignedIntType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof LongType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) ); - } else if (rai.getAt(0) instanceof FloatType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((float) ((i.get() - mean) / (std + eps))) ); - } else if (rai.getAt(0) instanceof DoubleType) { - LoopBuilder.setImages((RandomAccessibleInterval) rai ) - .multiThreaded() - .forEachPixel( i -> i.set((double) ((i.get() - mean) / (std + eps)) ) ); - } else { - throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); - } - */ } }