Skip to content

Commit

Permalink
correct getting image data type
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 2, 2024
1 parent 5ddf27d commit b82c705
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 206 deletions.
3 changes: 2 additions & 1 deletion src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
* Class that manages a Deep Learning model to load it and run it.
Expand Down Expand Up @@ -536,7 +537,7 @@ void runModel( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )
engineClassLoader.setEngineClassLoader();
ArrayList<Tensor<FloatType>> inTensorsFloat = new ArrayList<Tensor<FloatType>>();
for (Tensor<T> tt : inTensors) {
if (tt.getData().getAt(0) instanceof FloatType)
if (Util.getTypeFromInterval(tt.getData()) instanceof FloatType)
inTensorsFloat.add(Cast.unchecked(tt));
else
inTensorsFloat.add(Tensor.createCopyOfTensorInWantedDataType( tt, new FloatType() ));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,12 @@
package io.bioimage.modelrunner.transformations;

import io.bioimage.modelrunner.tensor.Tensor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

/**
* Abstract classes for tensor transformations where a new pixel value can be
Expand Down Expand Up @@ -69,7 +61,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply(
public < R extends RealType< R > & NativeType< R > >
void applyInPlace( final Tensor< R > input )
{
if (input.getData().getAt(0) instanceof IntegerType && dun != null) {
if (Util.getTypeFromInterval(input.getData()) instanceof IntegerType && dun != null) {
LoopBuilder
.setImages( input.getData() )
.multiThreaded()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,47 +233,5 @@ void scaleLinear(RandomAccessibleInterval<R> rai, double gain, double offset) {
.multiThreaded()
.forEachPixel( i -> i.setReal((i.getRealDouble() * gain + offset) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) (i.get() * gain + offset) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,75 +152,6 @@ private < R extends RealType< R > & NativeType< R > > void globalScale( final Te
cursor.next();
flatArr[count ++] = cursor.get().getRealDouble();
}
/*
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedByteType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof ShortType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedShortType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof IntType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedIntType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof LongType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof FloatType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof DoubleType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
Arrays.sort(flatArr);

int percentilePos = (int) (flatSize * percentile);
Expand Down Expand Up @@ -321,47 +252,5 @@ void scaleRange(RandomAccessibleInterval<R> rai, double maxPercentileVal, double
.multiThreaded()
.forEachPixel( i -> i.setReal(((i.getRealDouble() - minPercentileVal) / (diff + eps)) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) ((i.get() - minPercentileVal) / (diff + eps))) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -389,47 +389,5 @@ void zeroMeanUnitVariance(RandomAccessibleInterval<R> rai, double mean, double s
.multiThreaded()
.forEachPixel( i -> i.setReal(((i.getRealDouble() - mean) / (std + eps)) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) ((i.get() - mean) / (std + eps))) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) ((i.get() - mean) / (std + eps)) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}

0 comments on commit b82c705

Please sign in to comment.