Skip to content

Commit

Permalink
correct same errors that affected javacpp
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 25, 2024
1 parent a7c7d09 commit c5fce53
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

import org.tensorflow.types.TFloat32;
Expand Down Expand Up @@ -104,7 +105,7 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
byte[] flat = new byte[buff.capacity()];
ByteBuffer buff2 = ByteBuffer.wrap(flat);
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
buff = buff2;
buff.put(buff2);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -120,7 +121,7 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
byte[] flat = new byte[buff.capacity()];
ByteBuffer buff2 = ByteBuffer.wrap(flat);
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
buff = buff2;
buff.put(buff2);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -134,9 +135,14 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
byte[] flat = new byte[buff.capacity()];
ByteBuffer buff2 = ByteBuffer.wrap(flat);
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
buff = buff2;
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
tensor.asRawTensor().data().asFloats().read(buff2.asFloatBuffer().array(), 0, buff.capacity());
buff.put(buff2);

float sum = 0;
for (float ff : buff2.asFloatBuffer().array())
sum += ff;
System.out.println("SECOND SUM: " + sum);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -152,7 +158,7 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
byte[] flat = new byte[buff.capacity()];
ByteBuffer buff2 = ByteBuffer.wrap(flat);
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
buff = buff2;
buff.put(buff2);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -169,7 +175,7 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws
byte[] flat = new byte[buff.capacity()];
ByteBuffer buff2 = ByteBuffer.wrap(flat);
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
buff = buff2;
buff.put(buff2);
if (PlatformDetection.isWindows()) shma.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
float[] flat = new float[buff.capacity() / 4];
buff.asFloatBuffer().get(flat);
buff.rewind();
float sum = 0;
for (float ff : flat)
sum += ff;
System.out.println(sum);
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand Down

0 comments on commit c5fce53

Please sign in to comment.