From fc7c6d4e826f36af16d0a81c6adbb38cae34acee Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 19 Dec 2019 12:10:06 +0200 Subject: [PATCH] Shugeo roll fix3 (#127) * Added tests for roll with scalar shift and axis. * Fixed problem with roll on 1D input with scalar axis and test. * Only cosmetic changes. --- .../declarable/generic/parity_ops/roll.cpp | 2 +- .../layers_tests/DeclarableOpsTests7.cpp | 38 +++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 61f592f1d..f93fc198e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -84,7 +84,7 @@ namespace ops { if (block.isInplace()) output = input; - shiftIsLinear = axes.size() == 0; + shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); if (shiftIsLinear) { helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 220191011..ffb847dbd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -3148,14 +3148,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); -// output->printShapeInfo("Output shape"); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result->at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -3240,10 +3233,6 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); - -// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, -// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 -// 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; @@ -3269,10 +3258,6 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); - -// 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, -// 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 -// 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; @@ -3518,6 +3503,27 @@ TEST_F(DeclarableOpsTests7, TestRoll_14) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_15) { + auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f }); + auto shift = NDArrayFactory::create(2); + auto axis = NDArrayFactory::create(0); + + auto exp = NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); +// out->printIndexedBuffer("Output 15"); +// exp.printIndexedBuffer("Expect 15"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) {