From 5893e56a485cb8418ce81b1c08a1023c67dd7610 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Mon, 25 Mar 2024 20:11:39 +0100 Subject: [PATCH] improve robustness --- .../v2/api050/tensor/ImgLib2Builder.java | 19 +++++++++++++++++- .../v2/api050/tensor/TensorBuilder.java | 20 +++++++++---------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/ImgLib2Builder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/ImgLib2Builder.java index 3639e2d..736065b 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/ImgLib2Builder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/ImgLib2Builder.java @@ -21,7 +21,7 @@ package io.bioimage.modelrunner.tensorflow.v2.api050.tensor; import io.bioimage.modelrunner.tensor.Utils; - +import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.array.ArrayImgs; import net.imglib2.type.Type; @@ -31,6 +31,8 @@ import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; +import java.util.Arrays; + import org.tensorflow.Tensor; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -103,6 +105,9 @@ else if (tensor instanceof TInt64) private static RandomAccessibleInterval buildFromTensorUByte(TUint8 tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 1)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -123,6 +128,9 @@ private static RandomAccessibleInterval buildFromTensorUByte(T private static RandomAccessibleInterval buildFromTensorInt(TInt32 tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -143,6 +151,9 @@ private static RandomAccessibleInterval buildFromTensorInt(TInt32 tenso private static RandomAccessibleInterval buildFromTensorFloat(TFloat32 tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 4)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -163,6 +174,9 @@ private static RandomAccessibleInterval buildFromTensorFloat(TFloat32 private static RandomAccessibleInterval buildFromTensorDouble(TFloat64 tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; @@ -183,6 +197,9 @@ private static RandomAccessibleInterval buildFromTensorDouble(TFloat private static RandomAccessibleInterval buildFromTensorLong(TInt64 tensor) { long[] arrayShape = tensor.shape().asArray(); + if (CommonUtils.int32Overflows(arrayShape, 8)) + throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) + + " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8); long[] tensorShape = new long[arrayShape.length]; for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i]; int totalSize = 1; diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java index 610b66f..83577fe 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java @@ -134,9 +134,9 @@ public static TUint8 buildUByte(RandomAccessibleInterval tenso throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 1)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -171,9 +171,9 @@ public static TInt32 buildInt(RandomAccessibleInterval tensor) throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 4)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -209,9 +209,9 @@ private static TInt64 buildLong(RandomAccessibleInterval tensor) throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 8)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -248,9 +248,9 @@ public static TFloat32 buildFloat( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 4)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; @@ -286,9 +286,9 @@ private static TFloat64 buildDouble( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); - if (CommonUtils.int32Overflows(ogShape)) + if (CommonUtils.int32Overflows(ogShape, 8)) throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) - + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); + + " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8); tensor = Utils.transpose(tensor); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1;