diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 21eedc665..10847f882 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1422,6 +1422,7 @@ namespace nd4j { template void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value); + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, NDArray const& value); template diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 0f0621a80..1d810b421 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -4187,6 +4187,24 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { NDArray::registerPrimaryUse({this}, {&scalar}); } +//////////////////////////////////////////////////////////////////////// + void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { + + if(!scalar.isScalar()) + throw std::invalid_argument("NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); + +// void *p = reinterpret_cast(scalar.getBuffer()); + Nd4jLong coords[4] = {i, j, k, l}; + auto xOffset = shape::getOffset(getShapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {&scalar}, true); +// BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (this->getBuffer(), xOffset, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); + } + ////////////////////////////////////////////////////////////////////////// void NDArray::addRowVector(const NDArray *row, NDArray *target) const { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp index 6f32c1cbd..1d1e32b52 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp @@ -25,7 +25,53 @@ namespace ops { namespace helpers { void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { + // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set + // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as + // floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd + // floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd + // height = images->sizeAt(1), width = images->sizeAt(2) + // colors - colors for each box given + // set up color for each box as frame + auto height = images->sizeAt(1); + auto width = images->sizeAt(2); + auto channels = images->sizeAt(3); + auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch + auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch + output->assign(images); + for (auto b = 0; b < imageList->size(); ++b) { // loop by batch +// auto image = imageList->at(b); + auto box = boxList->at(b); + + auto internalBoxes = box->allTensorsAlongDimension({1}); + auto colorSet = colors->allTensorsAlongDimension({1}); + + for (auto c = 0; c < colorSet->size(); ++c) { + // box with shape + auto internalBox = internalBoxes->at(c); + auto color = colorSet->at(c); + auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox->e(0))); + auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox->e(2))); + auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox->e(1))); + auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox->e(3))); + for (auto y = rowStart; y <= rowEnd; y++) { + for (auto e = 0; e < color->lengthOf(); ++e) { + output->p(b, y, colStart, e, color->e(e)); + output->p(b, y, colEnd, e, color->e(e)); + } + } + for (auto x = colStart + 1; x < colEnd; x++) { + for (auto e = 0; e < color->lengthOf(); ++e) { + output->p(b, rowStart, x, e, color->e(e)); + output->p(b, rowEnd, x, e, color->e(e)); + } + } + } + delete colorSet; + delete internalBoxes; + } + delete imageList; + delete boxList; } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index ce327f0cf..c61cda29f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2051,7 +2051,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { 0.3, 0.3, 0.7, 0.7, 0.4, 0.4, 0.6, 0.6 }); - NDArray colors = NDArrayFactory::create('c', {2, 3}, {201., 202., 203., 128., 129., 130.}); + NDArray colors = NDArrayFactory::create('c', {2, 3}, {201., 202., 203., 127., 128., 129.}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); NDArray expected = NDArrayFactory::create('c', {2,4,5,3}, { @@ -2072,7 +2072,39 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - result->printIndexedBuffer("Bounded boxes"); + result->printBuffer("Bounded boxes"); + expected.printBuffer("Bounded expec"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { + NDArray images = NDArrayFactory::create('c', {1,9,9,1}); + NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7}); + NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1,9,9,1}, { + 1.1 , 2.1, 3.1 , 4.1 , 5.1 , 6.1 , 7.1 , 8.1 , 9.1 , + 10.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 16.1 , 17.1 , 18.1 , + 19.1 , 0.95, 21.1, 22.1, 23.1, 0.95, 25.1 , 26.1 , 27.1 , + 28.1 , 0.95, 30.1, 31.1, 32.1, 0.95, 34.1 , 35.1 , 36.1 , + 37.1 , 0.95, 39.1, 40.1, 41.1, 0.95, 43.1 , 44.1 , 45.1 , + 46.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 52.1 , 53.1 , 54.1 , + 55.1 , 56.1, 57.1 , 58.1 , 59.1 , 60.1 , 61.1 , 62.1 , 63.1 , + 64.1 , 65.1, 66.1 , 67.1 , 68.1 , 69.1 , 70.1 , 71.1 , 72.1 , + 73.1 , 74.1, 75.1 , 76.1 , 77.1 , 78.1 , 79.1 , 80.1 , 81.1 }); + images.linspace(1.1); + nd4j::ops::draw_bounding_boxes op; + auto results = op.execute({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + result->printIndexedBuffer("Bounded boxes 2"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result));