From 190575196cfe900d117d142068bbee66b406b42c Mon Sep 17 00:00:00 2001 From: shugeo Date: Tue, 3 Dec 2019 14:06:38 +0200 Subject: [PATCH] Refactored pad and mirror_pad ops to conform with TF. (#100) --- .../include/ops/declarable/generic/transforms/mirrorPad.cpp | 2 +- libnd4j/include/ops/declarable/generic/transforms/pad.cpp | 3 ++- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index 603bfdf61..fac8451a5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) { DECLARE_TYPES(mirror_pad) { getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index 31a5d25b3..9d410a6c3 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -78,7 +78,8 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { DECLARE_TYPES(pad) { getOpDescriptor() ->setAllowedInputTypes(0, nd4j::DataType::ANY) - ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes + ->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF +// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes ->setSameMode(true); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index f232411b2..23351f7af 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -4549,7 +4549,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test13) { TEST_F(DeclarableOpsTests7, mirrorPad_test14) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); @@ -4567,7 +4567,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test14) { TEST_F(DeclarableOpsTests7, mirrorPad_test15) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});