diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 3051de448..94675c587 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -29,104 +29,102 @@ #include #include -namespace nd4j { - namespace ops { - namespace helpers { +namespace nd4j { +namespace ops { +namespace helpers { + /////////////////////////////////////////////////////////////////// - template - __global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { +template +__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { - __shared__ int arrIdx, blocksPerArr; + T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ int rank; - if (threadIdx.x == 0) { + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil - arrIdx = blockIdx.x / blocksPerArr; - } - - __syncthreads(); - - for(int j = arrIdx; j < numOfArrs; j += gridDim.x) { - - const auto* x = reinterpret_cast(reinterpret_cast(pVx)[j]); - auto* z = reinterpret_cast(reinterpret_cast(pVz)[j]); - const auto* xShapeInfo = reinterpret_cast(pxShapeInfo)[j]; - const auto* zShapeInfo = reinterpret_cast(pzShapeInfo)[j]; - - const auto arrLen = shape::length(xShapeInfo); - - const auto arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil - - const auto start = (blockIdx.x % blocksPerArr) * arrLenPerBlock; - const auto end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock); - - for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x) - z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)]; - } - } - -/////////////////////////////////////////////////////////////////// - template - __host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { - - concatCuda<<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo); - } - BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES); - - ////////////////////////////////////////////////////////////////////////// - void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfArrs = inArrs.size(); - for(int i = 0; i < numOfArrs; ++i) - if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice(); - - const int rank = inArrs[0]->rankOf(); - const int rank2 = 2*rank; - std::vector> indices(numOfArrs, std::vector(rank2,0)); - - // take into account indices for first array - indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); - - // loop through the rest of input arrays - for(int i = 1; i < numOfArrs; ++i) { - indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from - indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) - } - - std::vector outSubArrs(numOfArrs); - for(int i = 0; i < numOfArrs; ++i) - outSubArrs[i] = new NDArray(output(indices[i], true)); - - // prepare arrays of pointers on buffers and shapes - std::vector hOutBuffers(numOfArrs), hInBuffers(numOfArrs); - std::vector hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs); - for(int i = 0; i < numOfArrs; ++i) { - hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer(); - hInBuffers[i] = inArrs[i]->getSpecialBuffer(); - hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo(); - hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo(); - } - - // allocate and copy all buffers and shapes arrays to global memory - PointersManager manager(context, "helpers::concat"); - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); - void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*)); - - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES); - - manager.synchronize(); - - for(int i = 0; i < numOfArrs; ++i) - delete outSubArrs[i]; - - for(int i = 0; i < numOfArrs; ++i) - inArrs[i]->tickReadHost(); - - output.tickWriteDevice(); - } - } + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + if(tid >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(rank, zShapeInfo + 1, tid, zLen, coords); + + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + + int inArrIdx = 0; + Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + + while(coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } + + const auto* x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + + z[zOffset] = x[xOffset]; +} + +/////////////////////////////////////////////////////////////////// +template +__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { + + concatCuda<<>>(pVx, pxShapeInfo, vz, zShapeInfo, axis); +} +BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; + + const int numOfArrs = inArrs.size(); + + for(int i = 0; i < numOfArrs; ++i) + inArrs[i]->syncToDevice(); + + output.syncToDevice(); + + // prepare arrays of pointers on buffers and shapes + std::vector hInBuffers(numOfArrs); + std::vector hInShapeInfo(numOfArrs); + + for(int i = 0; i < numOfArrs; ++i) { + hInBuffers[i] = inArrs[i]->getSpecialBuffer(); + hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo(); + } + + PointersManager manager(context, "helpers::concat"); + + void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); + void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); + + BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES); + + manager.synchronize(); + + for(int i = 0; i < numOfArrs; ++i) + inArrs[i]->tickReadDevice(); + + output.tickWriteDevice(); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 6fe3dfac6..7fbc309d5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -750,59 +750,6 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) { } -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_test10) { - - NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32); - NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32); - NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32); - - x0 = 0.; - x1 = 1.; - - nd4j::ops::concat op; - auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_14) { - - NDArray x0('c', {1,6}, {1,2,3,4,5,6}); - NDArray x1('c', {1,6}, {7,8,9,10,11,12}); - NDArray output('f', {2,6}, nd4j::DataType::DOUBLE); - NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); - - nd4j::ops::concat op; - - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - // output.printBuffer(); - // output.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(output)); -} - - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_15) { - - NDArray x0('c', {1,4}, {1,2,3,4}); - NDArray x1('c', {1,4}, {5,6,7,8}); - NDArray output('c', {2,4}, nd4j::DataType::DOUBLE); - NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); - - nd4j::ops::concat op; - - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - // output.printBuffer(); - // output.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(output)); -} - - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 06d677b27..df1421d71 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -364,77 +364,6 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { delete result; } -TEST_F(DeclarableOpsTests15, test_concat_column_1) { - auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); - auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto z = NDArrayFactory::create('c', {2, 2}); - - nd4j::ops::concat op; - auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); - ASSERT_EQ(Status::OK(), status); - - z.printIndexedBuffer("z"); - - ASSERT_EQ(e, z); -} - -TEST_F(DeclarableOpsTests15, test_concat_large_1) { - std::array arrays; - Context context(1); - Nd4jLong axis = 0; - - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < arrays.size(); e++) { - auto array = NDArrayFactory::create_('c', {1, 300}); - array->assign(e); - context.setInputArray(e, array, true); - } - - auto z = NDArrayFactory::create('c', {2000, 300}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); - - nd4j::ops::concat op; - op.execute(&context); - - for (int e = 0; e < arrays.size(); e++) { - auto row = z.tensorAlongDimension(e, {1}); - - ASSERT_NEAR((float) e, row->e(0), 1e-5f); - - delete row; - } -} - -TEST_F(DeclarableOpsTests15, test_concat_large_2) { - std::array arrays; - Context context(1); - Nd4jLong axis = 0; - - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < arrays.size(); e++) { - auto array = NDArrayFactory::create_('c', {1, 5, 20}); - array->assign(e); - context.setInputArray(e, array, true); - } - - auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); - - nd4j::ops::concat op; - op.execute(&context); - - for (int e = 0; e < arrays.size(); e++) { - auto row = z.tensorAlongDimension(e, {1, 2}); - - ASSERT_NEAR((float) e, row->meanNumber().e(0), 1e-5f); - - delete row; - } -} - TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index d02ddcb69..62172dbf2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -373,35 +373,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { delete result; } -TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) { - auto x0 = NDArrayFactory::create('c', {1, 100, 150}); - auto x1 = NDArrayFactory::create('c', {1, 100, 150}); - auto x2 = NDArrayFactory::create('c', {1, 100, 150}); - auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - - x0.assign(1.0); - x1.assign(2.0); - x2.assign(3.0); - x3.assign(4.0); - - nd4j::ops::concat op; - auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); - ASSERT_TRUE(4 == numOfTads); - - for (int e = 0; e < numOfTads; e++) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((double) e+1, mean, 1e-5); - } - - delete result; -} - TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) { auto A = NDArrayFactory::create('c', {3, 3}); auto B = NDArrayFactory::create('c', {3, 1}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 5568e5119..48996f2a5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -2845,31 +2845,4 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { delete result; } -TEST_F(DeclarableOpsTests6, concat_test14) { - - NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE); - NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE); - - x0 = 1.; - x1 = 2.; - - nd4j::ops::concat op; - auto result = op.execute({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - // z->printShapeInfo(); - // z->printIndexedBuffer(); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); - - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } - - delete result; -} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 0d9ea7c24..d27aa4e46 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -584,6 +584,180 @@ TEST_F(DeclarableOpsTests9, concat_test16) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test17) { + + NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE); + NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE); + + x0 = 1.; + x1 = 2.; + + nd4j::ops::concat op; + auto result = op.execute({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + // z->printShapeInfo(); + // z->printIndexedBuffer(); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e+1)*1., mean, 1e-5); + } + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test18) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 300}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1}); + + ASSERT_NEAR((float) e, row->e(0), 1e-5f); + + delete row; + } +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test19) { + + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 5, 20}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) + ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e(0), 1e-5f); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test20) { + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); + + x0.assign(1.0); + x1.assign(2.0); + x2.assign(3.0); + x3.assign(4.0); + + nd4j::ops::concat op; + auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + ASSERT_TRUE(4 == numOfTads); + + for (int e = 0; e < numOfTads; e++) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((double) e+1, mean, 1e-5); + } + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test21) { + + NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32); + NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32); + NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32); + + x0 = 0.; + x1 = 1.; + + nd4j::ops::concat op; + auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test22) { + + NDArray x0('c', {1,6}, {1,2,3,4,5,6}); + NDArray x1('c', {1,6}, {7,8,9,10,11,12}); + NDArray output('f', {2,6}, nd4j::DataType::DOUBLE); + NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); + + nd4j::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test23) { + + NDArray x0('c', {1,4}, {1,2,3,4}); + NDArray x1('c', {1,4}, {5,6,7,8}); + NDArray output('c', {2,4}, nd4j::DataType::DOUBLE); + NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); + + nd4j::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test24) { + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); + + nd4j::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test1) {