diff --git a/src/main/java/io/bioimage/modelrunner/pytorch/tensor/NDArrayBuilder.java b/src/main/java/io/bioimage/modelrunner/pytorch/tensor/NDArrayBuilder.java index b7db0f7..e2177fb 100644 --- a/src/main/java/io/bioimage/modelrunner/pytorch/tensor/NDArrayBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/pytorch/tensor/NDArrayBuilder.java @@ -27,11 +27,13 @@ import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.Img; -import net.imglib2.type.Type; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.ByteType; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; import net.imglib2.util.Util; import net.imglib2.view.Views; @@ -60,26 +62,9 @@ public class NDArrayBuilder { * @return The {@link NDArray} built from the {@link Tensor}. * @throws IllegalArgumentException If the tensor type is not supported. */ - public static NDArray build(Tensor tensor, NDManager manager) - throws IllegalArgumentException - { - // Create an Icy sequence of the same type of the tensor - if (Util.getTypeFromInterval(tensor.getData()) instanceof ByteType) { - return buildFromTensorByte(tensor.getData(), manager); - } - else if (Util.getTypeFromInterval(tensor.getData()) instanceof IntType) { - return buildFromTensorInt(tensor.getData(), manager); - } - else if (Util.getTypeFromInterval(tensor.getData()) instanceof FloatType) { - return buildFromTensorFloat(tensor.getData(), manager); - } - else if (Util.getTypeFromInterval(tensor.getData()) instanceof DoubleType) { - return buildFromTensorDouble(tensor.getData(), manager); - } - else { - throw new IllegalArgumentException("Unsupported tensor type: " + tensor - .getDataType()); - } + public static & NativeType> + NDArray build(Tensor tensor, NDManager manager) throws IllegalArgumentException { + return build(tensor.getData(), manager); } /** @@ -94,25 +79,21 @@ else if (Util.getTypeFromInterval(tensor.getData()) instanceof DoubleType) { * @return The {@link NDArray} built from the {@link RandomAccessibleInterval}. * @throws IllegalArgumentException if the {@link RandomAccessibleInterval} is not supported */ - public static > NDArray build( - RandomAccessibleInterval tensor, NDManager manager) + public static & NativeType> + NDArray build(RandomAccessibleInterval tensor, NDManager manager) throws IllegalArgumentException { if (Util.getTypeFromInterval(tensor) instanceof ByteType) { - return buildFromTensorByte((RandomAccessibleInterval) tensor, - manager); + return buildFromTensorByte(Cast.unchecked(tensor), manager); } else if (Util.getTypeFromInterval(tensor) instanceof IntType) { - return buildFromTensorInt((RandomAccessibleInterval) tensor, - manager); + return buildFromTensorInt(Cast.unchecked(tensor), manager); } else if (Util.getTypeFromInterval(tensor) instanceof FloatType) { - return buildFromTensorFloat((RandomAccessibleInterval) tensor, - manager); + return buildFromTensorFloat(Cast.unchecked(tensor), manager); } else if (Util.getTypeFromInterval(tensor) instanceof DoubleType) { - return buildFromTensorDouble( - (RandomAccessibleInterval) tensor, manager); + return buildFromTensorDouble(Cast.unchecked(tensor), manager); } else { throw new IllegalArgumentException("Unsupported tensor type: " + Util @@ -120,16 +101,6 @@ else if (Util.getTypeFromInterval(tensor) instanceof DoubleType) { } } - /** - * Builds a {@link NDArray} from a signed byte-typed - * {@link RandomAccessibleInterval}. - * - * @param tensor - * the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray} - * @param manager - * {@link NDManager} needed to create a {@link NDArray} - * @return The {@link NDArray} built from the tensor of type {@link ByteType}. - */ private static NDArray buildFromTensorByte( RandomAccessibleInterval tensor, NDManager manager) { @@ -156,16 +127,6 @@ private static NDArray buildFromTensorByte( return ndarray; } - /** - * Builds a {@link NDArray} from a signed integer-typed - * {@link RandomAccessibleInterval}. - * - * @param tensor - * the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray} - * @param manager - * {@link NDManager} needed to create a {@link NDArray} - * @return The {@link NDArray} built from the tensor of type {@link IntType}. - */ private static NDArray buildFromTensorInt( RandomAccessibleInterval tensor, NDManager manager) { @@ -192,16 +153,6 @@ private static NDArray buildFromTensorInt( return ndarray; } - /** - * Builds a {@link NDArray} from a signed float-typed - * {@link RandomAccessibleInterval}. - * - * @param tensor - * the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray} - * @param manager - * {@link NDManager} needed to create a {@link NDArray} - * @return The {@link NDArray} built from the tensor of type {@link FloatType}. - */ private static NDArray buildFromTensorFloat( RandomAccessibleInterval tensor, NDManager manager) { @@ -228,16 +179,6 @@ private static NDArray buildFromTensorFloat( return ndarray; } - /** - * Builds a {@link NDArray} from a signed double-typed - * {@link RandomAccessibleInterval}. - * - * @param tensor - * the {@link RandomAccessibleInterval} that will be copied into an {@link NDArray} - * @param manager - * {@link NDManager} needed to create a {@link NDArray} - * @return The {@link NDArray} built from the tensor of type {@link DoubleType}. - */ private static NDArray buildFromTensorDouble( RandomAccessibleInterval tensor, NDManager manager) {