Skip to content

Commit

Permalink
improve robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Mar 25, 2024
1 parent 8abcf8b commit 5893e56
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
package io.bioimage.modelrunner.tensorflow.v2.api050.tensor;

import io.bioimage.modelrunner.tensor.Utils;

import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
Expand All @@ -31,6 +31,8 @@
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;

import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
Expand Down Expand Up @@ -103,6 +105,9 @@ else if (tensor instanceof TInt64)
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(TUint8 tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 1))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -123,6 +128,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -143,6 +151,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tenso
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32 tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 4))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -163,6 +174,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat64 tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand All @@ -183,6 +197,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat
private static RandomAccessibleInterval<LongType> buildFromTensorLong(TInt64 tensor)
{
long[] arrayShape = tensor.shape().asArray();
if (CommonUtils.int32Overflows(arrayShape, 8))
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ public static TUint8 buildUByte(RandomAccessibleInterval<UnsignedByteType> tenso
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 1))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -171,9 +171,9 @@ public static TInt32 buildInt(RandomAccessibleInterval<IntType> tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -209,9 +209,9 @@ private static TInt64 buildLong(RandomAccessibleInterval<LongType> tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 8))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -248,9 +248,9 @@ public static TFloat32 buildFloat(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 4))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down Expand Up @@ -286,9 +286,9 @@ private static TFloat64 buildDouble(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
if (CommonUtils.int32Overflows(ogShape, 8))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
tensor = Utils.transpose(tensor);
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
Expand Down

0 comments on commit 5893e56

Please sign in to comment.