diff --git a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp index 54fd8fb0e..374456be6 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp @@ -26,24 +26,38 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - auto start = INPUT_VARIABLE(0); - auto finish = INPUT_VARIABLE(1); - auto numOfElements = INPUT_VARIABLE(2); + CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { - if (numOfElements->e(0) == 1) { + auto output = OUTPUT_VARIABLE(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); + + REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT()); + + auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) : static_cast(T_ARG(0)); + auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) : static_cast(T_ARG(1)); + auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); + + if (numOfElements == 1) { output->assign(start); return Status::OK(); } - output->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); + output->linspace(start, (finish - start) / ( numOfElements - 1.0 )); return Status::OK(); } DECLARE_SHAPE_FN(lin_space) { - auto dataType = ArrayOptions::dataType(inputShape->at(0)); - Nd4jLong steps = INPUT_VARIABLE(2)->e(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); + REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() ); + + + auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast(D_ARG(0)) : DataType::FLOAT32) ; + Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index f3131c193..8fae1b63c 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1433,16 +1433,20 @@ namespace sd { /** * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) * - * input params: + * optional input params: * 0 - startVal - NDArray scalar (float point) * 1 - finishVal - NDArray scalar (float point) * 2 - numOfElements - NDArray scalar (integer) - * + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements * output: * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. */ #if NOT_EXCLUDED(OP_lin_space) - DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); #endif /** diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 03e5ae53f..6d89bd182 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) { ASSERT_TRUE(expect.equalsTo(res)); +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + + NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, + 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test3) { + + NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + + ASSERT_EQ( res->dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 6f559230b..29c681544 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { ASSERT_EQ(Status::OK(), status); } +TEST_F(JavaInteropTests, test_linspace_shape_1) { + if (!Environment::getInstance()->isCPU()) + return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int) sd::DataType::FLOAT32; + auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index 4bc3b3f63..a9f964844 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -42,6 +42,9 @@ import java.util.Map; public class Linspace extends DynamicCustomOp { private DataType dataType; + private double start; + private double stop; + private long elements; public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType); @@ -54,7 +57,7 @@ public class Linspace extends DynamicCustomOp { } public Linspace(DataType dataType, double start, double stop, long number) { - this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number)); + this(start, stop, number, dataType); } public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) { @@ -67,6 +70,19 @@ public class Linspace extends DynamicCustomOp { addDArgument(dataType); } + public Linspace(double start, double stop, long number, @NonNull DataType dataType) { + super(new INDArray[]{}, null); + this.dataType = dataType; + addDArgument(dataType); + + this.start = start; + this.stop = stop; + this.elements = number; + + addTArgument(this.start, this.stop); + addIArgument(elements); + } + public Linspace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index bda208ce7..b21123085 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1947,7 +1947,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); - if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { + if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index f0488636f..93ff5cf52 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1754,7 +1754,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); - if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { + if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index adddf5e42..cd18e0f18 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -20475,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) * - * input params: + * optional input params: * 0 - startVal - NDArray scalar (float point) * 1 - finishVal - NDArray scalar (float point) * 2 - numOfElements - NDArray scalar (integer) - * + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements * output: * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 01dc83ee4..04ebaa0d7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; +import org.nd4j.linalg.api.ops.impl.shape.Linspace; import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; import org.nd4j.linalg.api.ops.impl.transforms.Cholesky; @@ -1803,6 +1804,16 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(ret[0], in); } + + @Test + public void testLinspaceSignature_1() throws Exception { + val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0]; + val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0]; + + assertEquals(array1.dataType(), array2.dataType()); + assertEquals(array1, array2); + } + @Test public void testLogdet() { INDArray x = Nd4j.createFromArray(new double[]{