From b786418c5df9fcbbf71b8e1c400b126a284c1a75 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 11 May 2020 21:29:52 +1000 Subject: [PATCH] Fix an issue when creating DataBuffer/INDArray from ByteBuffer for multiple datatypes (#446) * Fix missing dtypes when creating DataBuffer from ByteBuffer Signed-off-by: Alex Black * Revert LongIndexer -> ULongIndexer; fixes for UIntIndexer Signed-off-by: Alex Black * CUDA fix Signed-off-by: Alex Black --- .../linalg/api/buffer/BaseDataBuffer.java | 25 +++++++++++++++---- .../jcublas/buffer/BaseCudaDataBuffer.java | 10 ++++++-- .../nativecpu/buffer/BaseCpuDataBuffer.java | 24 +++++++++++++++--- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 19 ++++++++++++++ 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 212d55cd8..f2fb7d382 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case FLOAT: return ((FloatIndexer) indexer).get(i); case UINT32: + return ((UIntIndexer) indexer).get(i); case INT: return ((IntIndexer) indexer).get(i); case BFLOAT16: @@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer { return (long) ((Bfloat16Indexer) indexer).get(i); case HALF: return (long) ((HalfIndexer) indexer).get( i); - case UINT64: + case UINT64: //Fall through case LONG: return ((LongIndexer) indexer).get(i); case UINT32: + return (long) ((UIntIndexer) indexer).get(i); case INT: return (long) ((IntIndexer) indexer).get(i); case UINT16: @@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case BOOL: return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0); case UINT32: + return (short) ((UIntIndexer)indexer).get(i); case INT: return (short) ((IntIndexer) indexer).get(i); case UINT16: @@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case BOOL: return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f; case UINT32: + return (float) ((UIntIndexer)indexer).get(i); case INT: return (float) ((IntIndexer) indexer).get(i); case UINT16: @@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return (float) ((UByteIndexer) indexer).get(i); case BYTE: return (float) ((ByteIndexer) indexer).get(i); - case UINT64: + case UINT64: //Fall through case LONG: return (float) ((LongIndexer) indexer).get(i); case FLOAT: @@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case BOOL: return ((BooleanIndexer) indexer).get(i) ? 1 : 0; case UINT32: + return (int)((UIntIndexer) indexer).get(i); case INT: return ((IntIndexer) indexer).get(i); case BFLOAT16: @@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return ((UByteIndexer) indexer).get(i); case BYTE: return ((ByteIndexer) indexer).get(i); - case UINT64: + case UINT64: //Fall through case LONG: return (int) ((LongIndexer) indexer).get(i); case FLOAT: @@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer { ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: + ((UIntIndexer) indexer).put(i, (long)element); + break; case INT: ((IntIndexer) indexer).put(i, (int) element); break; @@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer { ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: + ((UIntIndexer) indexer).put(i, (long)element); + break; case INT: ((IntIndexer) indexer).put(i, (int) element); break; @@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer { ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: + ((UIntIndexer) indexer).put(i, element); + break; case INT: ((IntIndexer) indexer).put(i, element); break; - case UINT64: + case UINT64: //Fall through case LONG: ((LongIndexer) indexer).put(i, element); break; @@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer { case SHORT: ((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0); break; - case INT: case UINT32: + ((UIntIndexer) indexer).put(i, element ? 1 : 0); + break; + case INT: ((IntIndexer) indexer).put(i, element ? 1 : 0); break; case UINT64: @@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer { ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: + ((UIntIndexer) indexer).put(i, element); + break; case INT: ((IntIndexer) indexer).put(i, (int) element); break; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 5f1cfb99e..14e64df61 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: + this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); @@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; - case UINT64: + case UINT64: //Fall through case LONG: this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); @@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); @@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; - case UINT64: + case UINT64: //Fall through case LONG: this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index dd9bb1af1..7a2a8467a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if(dataType() == DataType.FLOAT16){ + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + } else if(dataType() == DataType.BFLOAT16){ + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if(dataType() == DataType.BOOL){ + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer(); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if(dataType() == DataType.UINT16){ + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + } else if(dataType() == DataType.UINT32){ + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + setIndexer(UIntIndexer.create((IntPointer) pointer)); + } else if (dataType() == DataType.UINT64) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); } Nd4j.getDeallocatorService().pickObject(this); @@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo } else if (dataType() == DataType.UINT32) { pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); - // FIXME: we need unsigned indexer here - setIndexer(IntIndexer.create((IntPointer) pointer)); + setIndexer(UIntIndexer.create((IntPointer) pointer)); if (initialize) fillPointerWithZero(); } else if (dataType() == DataType.UINT64) { pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); - // FIXME: we need unsigned indexer here setIndexer(LongIndexer.create((LongPointer) pointer)); if (initialize) @@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo // FIXME: need unsigned indexer here pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); + setIndexer(UIntIndexer.create((IntPointer) pointer)); } else if (dataType() == DataType.UINT64) { attached = true; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 77df067ca..da8983118 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(e, z); } + @Test + public void testCreateBufferFromByteBuffer(){ + + for(DataType dt : DataType.values()){ + if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) + continue; +// System.out.println(dt); + + int lengthBytes = 256; + int lengthElements = lengthBytes / dt.width(); + ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes); + + DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0); + INDArray arr = Nd4j.create(db, new long[]{lengthElements}); + + arr.toStringFull(); + } + } + @Override public char ordering() { return 'c';