diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index fb5acf19b..f1fda4a9a 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -53,6 +53,7 @@ inline pairwise::Ops fromBroadcastToPairwise(broadcast::Ops op) { case broadcast::LogicalXor: return pairwise::LogicalXor; case broadcast::LogicalNot: return pairwise::LogicalNot; case broadcast::LogicalAnd: return pairwise::LogicalAnd; + case broadcast::PowDerivative: return pairwise::PowDerivative; default: throw std::runtime_error("fromBroadcastToPairwise: Not convertible operation"); } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 7de54a858..ea32b154c 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -80,7 +80,8 @@ (30, LogicalAnd), \ (31, DivideNoNan), \ (32, IGamma), \ - (33, IGammac) + (33, IGammac),\ + (34, PowDerivative) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 256e37341..1bcd2df8b 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -52,6 +52,9 @@ namespace nd4j { static BroadcastOpsTuple Subtract(); static BroadcastOpsTuple IGamma(); static BroadcastOpsTuple IGammac(); + + static BroadcastOpsTuple Pow(); + static BroadcastOpsTuple PowDerivative(); }; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp index 7f7efd80c..56f77737d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp @@ -16,6 +16,7 @@ // // @author raver119@gmail.com +// @author Oleh Semeniv (oleg.semeniv@gmail.com) // #include @@ -25,7 +26,7 @@ #include namespace nd4j { - 
namespace ops { +namespace ops { BROADCASTABLE_OP_IMPL(Pow, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); @@ -51,7 +52,76 @@ namespace nd4j { ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); } - } + + CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s" + " and y %s are not suitable for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, + "POW_BP OP: wrong shape of next epsilon array (dLdOut)," + " expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); + + // dL/dy = x^y * log(x) * dL/dz + auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y + x->applyTransform(transform::Log, *dLdx); // b = log(x) + dLdx->applyScalar(nd4j::scalar::ReplaceNans, 0, *dLdx); + temp *= *dLdx; // c = b*a + temp *= *dLdz; // dL/dy = c * dL/dz + if (dLdy->isSameShape(*dLdz)) { + dLdy->assign(temp); + } + else { + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); + dLdy->assign(temp.reduceAlongDimension(reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) + } + + // dL/dx = y*x^(y-1) * dL/dz + x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, temp); // a = y*x^(y-1) + temp *= *dLdz; // dLdx = a*dL/dz + + if (dLdx->isSameShape(*dLdz)) { + dLdx->assign(temp); // dLdx = a*dL/dz + } + else { + std::vector axesForX = 
ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); + dLdx->assign(temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz + } + + return Status::OK(); + } + + DECLARE_SHAPE_FN(Pow_bp) { + + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; + + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); + } + + DECLARE_TYPES(Pow_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS }) + ->setAllowedOutputTypes({ ALL_FLOATS }); // TODO: maybe worth adding ALL_INTS + } + +} } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 7ee53b52a..9a2dc9f62 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -356,6 +356,7 @@ namespace nd4j { */ #if NOT_EXCLUDED(OP_Pow) DECLARE_BROADCASTABLE_OP(Pow, 0, 0); + DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); #endif /** diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index 0e9c99636..26cda74a4 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -55,4 +55,12 @@ namespace nd4j { return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac); } + + BroadcastOpsTuple BroadcastOpsTuple::Pow() { + return custom(nd4j::scalar::Pow, nd4j::pairwise::Pow, nd4j::broadcast::Pow); + } + BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { + return custom(nd4j::scalar::PowDerivative, nd4j::pairwise::PowDerivative, nd4j::broadcast::PowDerivative); + } + } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp 
b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 0d32daebd..96234f41d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -1279,4 +1279,335 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) { ASSERT_TRUE(expected.equalsTo(output)); delete result; -} \ No newline at end of file +} + +//////////////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { + + // same shape + NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, nd4j::DataType::FLOAT32); + + + NDArray dLdz('c', { 2,2,2 }, nd4j::DataType::FLOAT32); + NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, nd4j::DataType::FLOAT32); + + dLdz.assign(1.0); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &x, &y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test2) { + + NDArray x('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 3,2,1 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, nd4j::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &x, &y, &dLdz }, 
{}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; + +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { + + // y - same shape as dLdz + NDArray xY('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray yY('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. , 26.4, 28.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + xY.assign(4.0); + yY.assign(2.0); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsY->status()); + + auto* dLdxY = resultsY->at(0); + auto* dLdyY = resultsY->at(1); + + ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); + + delete resultsY; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { + + // x - same shape ad dLdz + NDArray yX('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray xX('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. 
, 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, nd4j::DataType::FLOAT32); + NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, nd4j::DataType::FLOAT32); + + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + dLdz.linspace(0.1, 0.1); + + nd4j::ops::Pow_bp op; + + xX.assign(2.0); + yX.assign(4.0); + + auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsX->status()); + + auto* dLdxX = resultsX->at(0); + auto* dLdyX = resultsX->at(1); + + ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); + ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); + ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); + ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); + + delete resultsX; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { + + // both single array + NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray yConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdxExp('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 1 }, nd4j::DataType::FLOAT32); + + xConst.assign(3.0); + yConst.assign(4.0); + dLdz.assign(1.0); + + dLdxExp.assign(4.0 * pow(3, 3)); + dLdyExp.assign(pow(3, 4) * log(3)); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { + + // x single array + NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + + xConst.assign(2.0); + y.assign(4.0); + dLdzC.linspace(0.1, 0.1); + + NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32); + 
NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32); + + nd4j::ops::Pow_bp op; + auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status()); + + auto* dLdxXC = resultsXC->at(0); + auto* dLdyXC = resultsXC->at(1); + + ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); + ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); + ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); + ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); + + delete resultsXC; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { + + // Y - scalar + auto Y = NDArrayFactory::create(2.f); + NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); + + dLdzC.linspace(0.1, 0.1); + x = 4.f; + + NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, nd4j::DataType::FLOAT32); + + auto dLdyExpYs = NDArrayFactory::create(79.85056f); + + nd4j::ops::Pow_bp op; + auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status()); + + auto* dLdxY = resultsYs->at(0); + auto* dLdyY = resultsYs->at(1); + + ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); + + delete resultsYs; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { + // both scalars + + auto X = NDArrayFactory::create(4.f); + auto Y = NDArrayFactory::create(2.f); + NDArray dLdz = NDArrayFactory::create(0.1f); + + NDArray dLdxExp = NDArrayFactory::create(2.f*4.f*0.1f); + + NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); + + nd4j::ops::Pow_bp op; + auto results = op.execute({ &X, &Y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + 
ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { + + nd4j::ops::Pow_bp op; + // diff shapes + NDArray x('c', { 3,2,1 }, nd4j::DataType::FLOAT32); + NDArray y('c', { 1,2,3 }, nd4j::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, nd4j::DataType::FLOAT32); + NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, nd4j::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + auto results = op.execute({ &x, &y, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto* dLdx = results->at(0); + auto* dLdy = results->at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + + delete results; +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { + + // diff shapes broadcastable + NDArray yB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32); + NDArray xB('c', { 2,3,1 }, nd4j::DataType::FLOAT32); + + NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, nd4j::DataType::FLOAT32); + NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, nd4j::DataType::FLOAT32); + NDArray dLdzB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32); + + dLdzB.linspace(0.1, 0.1); + xB.assign(4.0); + yB.assign(2.0); + + nd4j::ops::Pow_bp op; + auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); + + auto* dLdxB = resultsB->at(0); + auto* dLdyB = resultsB->at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); + + delete resultsB; +} + 
+TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { + + NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, nd4j::DataType::FLOAT32); + NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, nd4j::DataType::FLOAT32); + + NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN() }, nd4j::DataType::FLOAT32); + NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits::quiet_NaN() }, nd4j::DataType::FLOAT32); + + NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32); + + nd4j::ops::Pow_bp op; + auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB->status()); + auto* dLdxB = resultsB->at(0); + auto* dLdyB = resultsB->at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + for (int i = 0; i < dLdxB->lengthOf(); ++i) { + if (!nd4j::math::nd4j_isnan(dLdxB->e(i)) && !nd4j::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB->e(i), dLdxExpB.e(i), 0.00001); + } + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + for (int i = 0; i < dLdyB->lengthOf(); ++i) { + if (!nd4j::math::nd4j_isnan(dLdyB->e(i)) && !nd4j::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); + } + + delete resultsB; +} diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 0e320c726..0d5572ec6 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1090,7 +1090,7 @@ TEST_F(RNGTests, test_multinomial_5) { // multinomial as binomial if 2 classes used int batchValue = 1; int ClassValue = 2; - int Samples = 1000000; + int Samples = 100000; NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32); @@ -1107,8 +1107,8 @@ TEST_F(RNGTests, test_multinomial_5) { auto mean = output.meanNumber(); // 
printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); // theoretical values for binomial - ASSERT_NEAR(0.5, deviation.e(0), 3e-3); - ASSERT_NEAR(0.5, mean.e(0), 3e-3); + ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); + ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); for (int i = 0; i < output.lengthOf(); i++) { auto value = output.e(i); @@ -1122,8 +1122,8 @@ TEST_F(RNGTests, test_multinomial_5) { deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); mean = outputR->meanNumber(); // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(0.5, deviation.e(0), 35e-3); - ASSERT_NEAR(0.5, mean.e(0), 35e-3); + ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); for (int i = 0; i < outputR->lengthOf(); i++) { auto value = outputR->e(i); @@ -1138,7 +1138,7 @@ TEST_F(RNGTests, test_multinomial_6) { int batchValue = 1; int ClassValue = 5; - int Samples = 1000000; + int Samples = 100000; NDArray samples('c', { 1 }, { 1. 
* Samples }, nd4j::DataType::INT32); @@ -1165,14 +1165,14 @@ TEST_F(RNGTests, test_multinomial_6) { auto c = countsR.e(i); auto p = probExpect.e(i); // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 35e-3); + ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); } auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); auto mean = outputR->meanNumber(); // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 35e-3); - ASSERT_NEAR(2.906, mean.e(0), 35e-3); + ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); delete resultR; @@ -1195,12 +1195,12 @@ TEST_F(RNGTests, test_multinomial_6) { auto c = counts.e(i); auto p = probExpect.e(i); // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 3e-3); + ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); } deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); mean = output.meanNumber(); // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 3e-3); - ASSERT_NEAR(2.906, mean.e(0), 3e-3); + ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); + ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 0a725786e..e38af27d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -92,20 +92,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; import 
org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.*; import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; @@ -1420,6 +1407,10 @@ public class DifferentialFunctionFactory { return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); } + public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) { + return new PowBp(sameDiff(), x, pow, gradient).outputVariables(); + } + public SDVariable mishDerivative(SDVariable iX) { return new MishDerivative(sameDiff(), iX, false).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index e85c472c8..ad23d1266 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -230,6 +230,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scalar.LogX.class, org.nd4j.linalg.api.ops.impl.scalar.Pow.class, org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, + org.nd4j.linalg.api.ops.impl.reduce.bp.PowBp.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java new file mode 100644 index 000000000..c46414f79 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/PowBp.java @@ -0,0 +1,45 @@ +package org.nd4j.linalg.api.ops.impl.reduce.bp; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.BaseArithmeticBackpropOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class PowBp extends BaseDynamicTransformOp { + + public PowBp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable dLdz) { + super(sameDiff,new SDVariable[]{x,y,dLdz}, false); + } + + public PowBp(INDArray x, INDArray y, INDArray dLdz, + INDArray dLdx, INDArray dLdy) { + super(new INDArray[]{x,y, dLdz}, new INDArray[]{dLdx, dLdy}); + } + + @Override + public String opName() { + 
return "Pow_bp"; + } + + @Override + public boolean isInplaceCall() { + return false; + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes); + //Gradient types: same as input + return Arrays.asList(arg(0).dataType(), arg(1).dataType()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 8aafce3d1..08ead2683 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -19,7 +19,9 @@ package org.nd4j.linalg.api.ops.impl.scalar; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseTransformOp; @@ -29,6 +31,7 @@ import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -89,9 +92,8 @@ public class Pow extends BaseScalarOp { } @Override - public List doDiff(List i_v1) { + public List doDiff(List i_v1) { SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0)); return Arrays.asList(g); } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index df41438fe..e155a4f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -62,11 +62,14 @@ public class Pow extends DynamicCustomOp { //dL/da = b*a^(b-1) * dL/dy //dL/db = a^b * log(a) * dL/dy - SDVariable a = arg(0); + /*SDVariable a = arg(0); SDVariable b = arg(1); SDVariable dlda = b.mul(sameDiff.math().pow(a,b.sub(1))).mul(f1.get(0)); SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0)); - return Arrays.asList(dlda, dldb); + return Arrays.asList(dlda, dldb);*/ + + SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0)); + return Arrays.asList(g); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java index 58afb2acb..0fbfa2671 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java @@ -921,4 +921,60 @@ public class ReductionBpOpValidation extends BaseOpValidation { assertNull(err, err); } } + + @Test + public void testPowBP() { + + for (boolean keepDims : new boolean[]{false, true}) { + + INDArray preReduceInput_1 = Nd4j.createFromArray(new double[]{ + 4,3,2,5,7,8,-9,-12 + }).reshape(2,2,2); + INDArray preReduceInput_2 = Nd4j.createFromArray(new double[]{ + 2,3,-2,4,-1,-4,10,8 + }).reshape(2,2,2); + INDArray preReduceInput_3 = Nd4j.linspace(1, 8, 8).reshape(2, 2,2); + INDArray gradOutput = Nd4j.valueArrayOf(new long[]{2, 2, 2}, 1.0); + INDArray dLdInExpected_1 = 
Nd4j.createFromArray(new double[]{ + 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 + }).reshape(2,2,2); + INDArray dLdInExpected_2 = Nd4j.createFromArray(new double[]{ + 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 + }).reshape(2,2,2); + INDArray output1 = Nd4j.createUninitialized(2, 2,2); + INDArray output2 = Nd4j.createUninitialized(2, 2,2); + + String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2, + gradOutput, output1, output2)) + .expectedOutput(0, dLdInExpected_1).expectedOutput(1, dLdInExpected_2)); + + assertNull(err); + } + } + + @Test + public void testPowBP1() { + + INDArray preReduceInput_1 = Nd4j.createFromArray(new float[]{ + 0.0714f, 0.4735f, -0.1249f, 0.4482f, + -0.1376f, 0.5218f, 0.5558f, 0.2444f, + -0.5297f, 0.4291f, 0.4913f, -0.1178f + }).reshape(3,4); + INDArray preReduceInput_2 = Nd4j.scalar(2.0000f); + + INDArray gradOutput = Nd4j.valueArrayOf(new long[]{3, 4}, 1.0f); + + INDArray output1 = Nd4j.createUninitialized(DataType.FLOAT, 3,4); + INDArray output2 = Nd4j.scalar(DataType.FLOAT, 1.0); //Nd4j.createUninitialized(DataType.FLOAT, 3,4); + + INDArray expected1 = Nd4j.createFromArray(new float[]{ + 0.1428f, 0.9470f, -0.2498f, 0.8964f, + -0.2752f, 1.0436f, 1.1116f, 0.4888f, + -1.0594f, 0.8582f, 0.9826f, -0.2356f + }).reshape(3,4); + INDArray expected2 = Nd4j.scalar(DataType.FLOAT, -1.112316132); + String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2, + gradOutput, output1, output2)).expectedOutput(0, expected1).expectedOutput(1, expected2)); + assertNull(err); + } }