correct errors in interprocess communication
carlosuc3m committed Oct 2, 2024
1 parent 6375b7c commit 42bc209
Showing 2 changed files with 14 additions and 54 deletions.
Changed file 1 of 2:
@@ -98,13 +98,8 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryName)
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
- ByteBuffer buff = shma.getDataBuffer();
- int totalSize = 1;
- for (long i : arrayShape) {totalSize *= i;}
- byte[] flatArr = new byte[buff.capacity()];
- buff.get(flatArr);
- tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
- shma.setBuffer(ByteBuffer.wrap(flatArr));
+ ByteBuffer buff1 = shma.getDataBufferNoHeader();
+ tensor.rawData().read(buff1.array(), 0, buff1.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

@@ -116,13 +111,8 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
- ByteBuffer buff = shma.getDataBuffer();
- int totalSize = 4;
- for (long i : arrayShape) {totalSize *= i;}
- byte[] flatArr = new byte[buff.capacity()];
- buff.get(flatArr);
- tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
- shma.setBuffer(ByteBuffer.wrap(flatArr));
+ ByteBuffer buff1 = shma.getDataBufferNoHeader();
+ tensor.rawData().read(buff1.array(), 0, buff1.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

@@ -134,13 +124,8 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryName)
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
- ByteBuffer buff = shma.getDataBuffer();
- int totalSize = 4;
- for (long i : arrayShape) {totalSize *= i;}
- byte[] flatArr = new byte[buff.capacity()];
- buff.get(flatArr);
- tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
- shma.setBuffer(ByteBuffer.wrap(flatArr));
+ ByteBuffer buff1 = shma.getDataBufferNoHeader();
+ tensor.rawData().read(buff1.array(), 0, buff1.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

@@ -152,13 +137,8 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memoryName)
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
- ByteBuffer buff = shma.getDataBuffer();
- int totalSize = 8;
- for (long i : arrayShape) {totalSize *= i;}
- byte[] flatArr = new byte[buff.capacity()];
- buff.get(flatArr);
- tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
- shma.setBuffer(ByteBuffer.wrap(flatArr));
+ ByteBuffer buff1 = shma.getDataBufferNoHeader();
+ tensor.rawData().read(buff1.array(), 0, buff1.capacity());
if (PlatformDetection.isWindows()) shma.close();
}

@@ -171,13 +151,8 @@ private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName)


SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
- ByteBuffer buff = shma.getDataBuffer();
- int totalSize = 8;
- for (long i : arrayShape) {totalSize *= i;}
- byte[] flatArr = new byte[buff.capacity()];
- buff.get(flatArr);
- tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
- shma.setBuffer(ByteBuffer.wrap(flatArr));
+ ByteBuffer buff1 = shma.getDataBufferNoHeader();
+ tensor.rawData().read(buff1.array(), 0, buff1.capacity());
if (PlatformDetection.isWindows()) shma.close();
}
}
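All five output builders get the same replacement: instead of sizing an intermediate byte[] from the shape and offsetting past the header carried by the full getDataBuffer() view, the new code takes the header-less view and copies the tensor's raw bytes into it directly. Below is a minimal before/after sketch using plain java.nio.ByteBuffer stand-ins; the class and method names are illustrative and not part of the commit.

import java.nio.ByteBuffer;

// Illustrative stand-ins only: "tensorBytes" plays the role of tensor.rawData(),
// and the two ByteBuffers play the roles of shma.getDataBuffer() and
// shma.getDataBufferNoHeader() from the real code.
public class ShmCopySketch {

    // Old pattern (removed): size the payload from the shape, copy everything
    // through a temporary byte[], and place the payload after the header by hand.
    static byte[] copyOld(ByteBuffer tensorBytes, ByteBuffer shmWithHeader,
                          long[] shape, int bytesPerElement) {
        int payloadSize = bytesPerElement;
        for (long d : shape) payloadSize *= d;                   // payload size in bytes
        byte[] flat = new byte[shmWithHeader.capacity()];
        shmWithHeader.get(flat);                                 // header + old payload
        tensorBytes.get(flat, flat.length - payloadSize, payloadSize); // overwrite payload region
        return flat;                                             // then written back into shared memory
    }

    // New pattern: the header-less view already starts at the payload, so the
    // tensor bytes go in directly, with no size or offset arithmetic.
    static void copyNew(ByteBuffer tensorBytes, ByteBuffer shmNoHeader) {
        byte[] dst = new byte[shmNoHeader.capacity()];
        tensorBytes.get(dst);
        shmNoHeader.put(dst);
    }
}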
Changed file 2 of 2:
@@ -23,13 +23,6 @@

import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
- import net.imglib2.RandomAccessibleInterval;
- import net.imglib2.img.Img;
- import net.imglib2.type.numeric.integer.IntType;
- import net.imglib2.type.numeric.integer.LongType;
- import net.imglib2.type.numeric.integer.UnsignedByteType;
- import net.imglib2.type.numeric.real.DoubleType;
- import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;

import java.nio.ByteBuffer;
@@ -125,9 +118,7 @@ private static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
IntBuffer intBuff = buff.asIntBuffer();
- int[] intArray = new int[intBuff.capacity()];
- intBuff.get(intArray);
- IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
+ IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
@@ -143,9 +134,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
LongBuffer longBuff = buff.asLongBuffer();
- long[] longArray = new long[longBuff.capacity()];
- longBuff.get(longArray);
- LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
+ LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
@@ -161,9 +150,7 @@ private static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
FloatBuffer floatBuff = buff.asFloatBuffer();
- float[] floatArray = new float[floatBuff.capacity()];
- floatBuff.get(floatArray);
- FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
+ FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
@@ -179,9 +166,7 @@ private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
- double[] doubleArray = new double[doubleBuff.capacity()];
- doubleBuff.get(doubleArray);
- DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
+ DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleBuff.array(), false);
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
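The input builders drop their intermediate primitive arrays (int[], long[], float[], double[]) the same way: the header-less shared-memory bytes are viewed as a typed NIO buffer and handed to the TensorFlow data-buffer layer. Below is a minimal sketch of the int case; the ByteBuffer parameter stands in for SharedMemoryArray.getDataBufferNoHeader(), and DataBuffers.of(IntBuffer) from the public ndarray API is assumed here in place of the internal RawDataBufferFactory call used in the commit.

import java.nio.ByteBuffer;
import java.nio.IntBuffer;

import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.types.TInt32;

// Sketch only: build a TInt32 tensor from a header-less shared-memory buffer
// without an intermediate int[] copy.
public class BuildIntSketch {

    static TInt32 buildInt(ByteBuffer noHeader, long[] shape) {
        IntBuffer ints = noHeader.asIntBuffer();   // typed view over the same bytes
        IntDataBuffer data = DataBuffers.of(ints); // assumed public wrapper for the NIO buffer
        return TInt32.tensorOf(Shape.of(shape), data);
    }
}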