diff --git a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu index c899e0184..e492baf8e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu @@ -64,7 +64,11 @@ namespace helpers { const int threadsPerBlock = MAX_NUM_THREADS / 2; const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size(); - NDArray::prepareSpecialUse({outArr}, inArrs); + NDArray::prepareSpecialUse({outArr}, {}); + + // FIXME: !!! + for (auto v:inArrs) + NDArray::prepareSpecialUse({}, {v}); std::vector inputList(inArrs.size()); std::vector inputShapeList(inArrs.size()); @@ -88,8 +92,11 @@ namespace helpers { } manager.synchronize(); - NDArray::registerSpecialUse({outArr}, inArrs); + NDArray::registerSpecialUse({outArr}, {}); + // FIXME: !!! + for (auto v:inArrs) + NDArray::registerSpecialUse({}, {v}); } void stack(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray* outArr, const int dim) {