From 4e8f3a025faea346bda2e43d18767100c3e59753 Mon Sep 17 00:00:00 2001 From: Shams Ul Azeem Date: Tue, 24 Mar 2020 13:11:57 +0500 Subject: [PATCH] Fixing python object for obtaining scalars (#330) * Fixing python object for obtaining scalars Signed-off-by: shams * Fix variable name for stridePtr Signed-off-by: shams * Fix variable name for stridePtr Signed-off-by: shams Co-authored-by: Alex Black --- .../java/org/datavec/python/PythonObject.java | 26 +++++----- .../datavec/python/ScalarAndArrayTest.java | 48 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java index 0408e3a59..4a6a617d5 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -77,7 +77,7 @@ public class PythonObject { long address = bp.address(); long size = bp.capacity(); - NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build(); + NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.INT8).build(); nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject; } @@ -320,20 +320,23 @@ public class PythonObject { public NumpyArray toNumpy() throws PythonException{ PyObject np = PyImport_ImportModule("numpy"); PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); - if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){ + if (PyObject_IsInstance(nativePythonObject, ndarray) != 1){ throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array."); } Py_DecRef(ndarray); Py_DecRef(np); + Pointer objPtr = new Pointer(nativePythonObject); PyArrayObject npArr = new PyArrayObject(objPtr); Pointer ptr = PyArray_DATA(npArr); - SizeTPointer shapePtr = PyArray_SHAPE(npArr); long[] shape = new long[PyArray_NDIM(npArr)]; - shapePtr.get(shape, 0, shape.length); - SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + SizeTPointer shapePtr = PyArray_SHAPE(npArr); + if (shapePtr != null) + shapePtr.get(shape, 0, shape.length); long[] strides = new long[shape.length]; - stridesPtr.get(strides, 0, strides.length); + SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + if (stridesPtr != null) + stridesPtr.get(strides, 0, strides.length); int npdtype = PyArray_TYPE(npArr); DataType dtype; @@ -345,28 +348,27 @@ public class PythonObject { case NPY_SHORT: dtype = DataType.SHORT; break; case NPY_INT: - dtype = DataType.INT; break; + dtype = DataType.INT32; break; case NPY_LONG: dtype = DataType.LONG; break; case NPY_UINT: dtype = DataType.UINT32; break; case NPY_BYTE: - dtype = DataType.BYTE; break; + dtype = DataType.INT8; break; case NPY_UBYTE: - dtype = DataType.UBYTE; break; + dtype = DataType.UINT8; break; case NPY_BOOL: dtype = DataType.BOOL; break; case NPY_HALF: - dtype = DataType.HALF; break; + dtype = DataType.FLOAT16; break; case NPY_LONGLONG: dtype = DataType.INT64; break; case NPY_USHORT: dtype = DataType.UINT16; break; case NPY_ULONG: - dtype = DataType.UINT64; break; case NPY_ULONGLONG: dtype = DataType.UINT64; break; - default: + default: throw new PythonException("Unsupported array data type: " + npdtype); } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java b/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java new file mode 100644 index 000000000..e6b1bf606 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; + +@RunWith(Parameterized.class) +public class ScalarAndArrayTest { + + @Parameterized.Parameters(name = "{index}: Testing with INDArray={0}") + public static INDArray[] data() { + return new INDArray[]{ + Nd4j.scalar(10), + Nd4j.ones(10, 10, 10, 10) + }; + } + + private INDArray indArray; + + public ScalarAndArrayTest(INDArray indArray) { + this.indArray = indArray; + } + + @Test + public void testINDArray() throws PythonException { + assertEquals(indArray, new PythonObject(indArray).toNumpy().getNd4jArray()); + } +}