diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index daa9d9328..f89ee6e1d 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -46,26 +46,53 @@ namespace nd4j { + template ::value>::type> + ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); - ND4J_EXPORT NDArray operator-(const float&, const NDArray&); - ND4J_EXPORT NDArray operator-(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator-(const double&, const NDArray&); - ND4J_EXPORT NDArray operator-(const int&, const NDArray&); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); + + template ::value>::type> + ND4J_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); + + template ::value>::type> + ND4J_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); + template ::value>::type> + ND4J_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); + template ::value>::type> + ND4J_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); + + template ::type>::value && 
std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); + template ::type>::value && std::is_same::type>::value>::type> + ND4J_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); - ND4J_EXPORT NDArray operator+(const float&, const NDArray&); - ND4J_EXPORT NDArray operator+(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator+(const double&, const NDArray&); - ND4J_EXPORT NDArray operator+(const int&, const NDArray&); - ND4J_EXPORT NDArray operator*(const float&, const NDArray&); - ND4J_EXPORT NDArray operator*(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator*(const double&, const NDArray&); - ND4J_EXPORT NDArray operator*(const int&, const NDArray&); - ND4J_EXPORT NDArray operator/(const float&, const NDArray&); - ND4J_EXPORT NDArray operator/(const float16&, const NDArray&); - ND4J_EXPORT NDArray operator/(const double&, const NDArray&); - ND4J_EXPORT NDArray operator/(const int&, const NDArray&); ND4J_EXPORT NDArray mmul(const NDArray&, const NDArray&); @@ -323,7 +350,7 @@ namespace nd4j { * axis - axis along which to repeat elements * repeats - number of repetitions */ - NDArray* repeat(const int axis, const std::vector& repeats) const; + NDArray repeat(const int axis, const std::vector& repeats) const; /** * This method fills this array with zeros @@ -336,15 +363,7 @@ namespace nd4j { * @param array * @return */ - static NDArray quantize(NDArray &array); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - static NDArray* quantize(NDArray *array); + static NDArray quantize(const NDArray &array); /** * fill target array by repeating current array @@ -356,19 +375,16 @@ namespace nd4j { /** * creates array which points on 
certain sub-range of this array, sub-range is defined by given indices */ - NDArray* subarray(IndicesList& indices) const; - NDArray* subarray(const std::initializer_list& idx) const; - NDArray* subarray(const Intervals& idx) const; + NDArray subarray(IndicesList& indices) const; + NDArray subarray(const std::initializer_list& idx) const; + NDArray subarray(const Intervals& idx) const; /** * cast array elements to given dtype */ - template - NDArray* cast(); + NDArray cast(DataType dtype) const; - NDArray* cast(DataType dtype) const; - - void cast(NDArray* target, DataType dtype); + void cast(NDArray& target, DataType dtype); /** * returns _context @@ -455,16 +471,22 @@ namespace nd4j { /** * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array */ - NDArray permute(const std::initializer_list& dimensions) const; - NDArray permute(const std::vector& dimensions) const; - NDArray permute(const int* dimensions, const int rank) const; + NDArray permute(const std::initializer_list& dimensions) const &; + NDArray permute(const std::vector& dimensions) const &; + NDArray permute(const int* dimensions, const int rank) const &; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const int* dimensions, const int rank) &&; void permute(const int* dimensions, const int rank, NDArray& target) const; void permute(const std::vector& dimensions, NDArray& target) const; - NDArray permute(const std::initializer_list& dimensions) const; - NDArray permute(const std::vector& dimensions) const; - NDArray permute(const Nd4jLong* dimensions, const int rank) const; + NDArray permute(const std::initializer_list& dimensions) const &; + NDArray permute(const std::vector& dimensions) const &; + NDArray permute(const Nd4jLong* dimensions, const int rank) const &; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& 
dimensions) &&; + NDArray permute(const Nd4jLong* dimensions, const int rank) &&; void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const; void permute(const std::vector& dimensions, NDArray& target) const; @@ -522,24 +544,13 @@ namespace nd4j { /** * this method assigns given value to all elements in array */ - void assign(const double value, bool allowParallelism = true); - void assign(const float value, bool allowParallelism = true); - void assign(const float16 value, bool allowParallelism = true); - void assign(const bfloat16& value, bool allowParallelism = true); - void assign(const Nd4jLong value, bool allowParallelism = true); - void assign(const int value, bool allowParallelism = true); - void assign(const int16_t value, bool allowParallelism = true); - void assign(const uint8_t value, bool allowParallelism = true); - void assign(const uint16_t value, bool allowParallelism = true); - void assign(const uint32_t value, bool allowParallelism = true); - void assign(const uint64_t value, bool allowParallelism = true); - void assign(const int8_t value, bool allowParallelism = true); - void assign(const bool value, bool allowParallelism = true); + template ::value>::type> + void assign(const T& value, bool allowParallelism = true); /** * returns new copy of this array, optionally in different order */ - NDArray *dup(const char newOrder = 'a') const; + NDArray dup(const char newOrder = 'a') const; /** * returns sum of all elements of array @@ -566,21 +577,17 @@ namespace nd4j { * keepDims - if true then put unities in place of reduced dimensions */ - NDArray* reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::FloatOps op, const 
std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& 
dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray* reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDims(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; /** * method reduces array by excluding its shapes along dimensions present in given dimensions vector @@ -589,10 +596,10 @@ namespace nd4j { * keepDims - if true then put unities in place of reduced dimensions * extras - extra parameters */ - void reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(nd4j::reduce::LongOps op, NDArray* target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + 
void reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; + void reduceAlongDimension(nd4j::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; /** * return variance of array elements set @@ -631,20 +638,24 @@ namespace nd4j { void makeBothActual() const { syncToDevice(); syncToHost(); } - void applyTransform(nd4j::transform::FloatOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::SameOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::AnyOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::BoolOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); - void applyTransform(nd4j::transform::StrictOps op, NDArray *target = nullptr, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr); + void applyTransform(nd4j::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams 
= nullptr); + void applyTransform(nd4j::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr); /** * apply OpName transformation to this array and store result in new array to be returned * extraParams - extra parameters for operation */ - NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) const; - NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) const; + NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) const &; + NDArray transform(nd4j::transform::FloatOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::SameOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::BoolOps op, void *extraParams = nullptr) &&; + NDArray transform(nd4j::transform::StrictOps op, void *extraParams = nullptr) &&; /** * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array @@ -659,11 +670,11 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void 
applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const; /** * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) @@ -672,23 +683,23 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tad, NDArray* target = nullptr, ExtraArguments* extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr); - void applyBroadcast(nd4j::broadcast::Ops op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::Ops op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr); /** * apply operation which requires broadcasting, 
broadcast one tensor along another, also this method checks the possibility of broadcasting * other - input array * extraParams - extra parameters for operation */ - NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const; - - NDArray* applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, ExtraArguments *extraArgs = nullptr) const; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&; + NDArray applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&; /** * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting @@ -697,11 +708,11 @@ namespace nd4j { * checkTargetShape - if true check whether target shape is suitable for broadcasting * extraParams - extra parameters for operation */ - void applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* 
target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; /** @@ -711,13 +722,13 @@ namespace nd4j { * extraParams - extra parameters for operation */ template - void applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray* target = nullptr, ExtraArguments *extraParams = nullptr); + void applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr); template - void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; template - void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; /** * apply a scalar operation to an array @@ -725,27 +736,27 @@ namespace nd4j { * target - where to store result * extraParams - extra parameters for operation */ - void applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target = nullptr, ExtraArguments *extraParams = nullptr); + void applyScalarArr(nd4j::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::IntOps op, const 
NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; #if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) template - FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyLambda(Lambda func, NDArray& target); template - FORCEINLINE void applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target); template - FORCEINLINE void applyIndexedLambda(Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); template - FORCEINLINE void applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target); template - FORCEINLINE void applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target = nullptr); + FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target); #else /** @@ -754,7 +765,7 @@ namespace nd4j { * target - where to store result */ template - void applyLambda(const std::function& func, NDArray* target = nullptr); + void applyLambda(const std::function& func, NDArray& target); /** * apply pairwise operation "func" to an array @@ -763,16 +774,16 @@ namespace nd4j { * target - where to store result */ template - void applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target = nullptr); + void applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); template - void applyIndexedLambda(const std::function& func, NDArray* target = nullptr); + void applyIndexedLambda(const std::function& func, NDArray& target); template - void applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target = nullptr); + void applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& 
target); template - void applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target = nullptr); + void applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target); #endif /** @@ -780,7 +791,7 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along * extraArgs - extra parameters for operation */ - NDArray* applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; + NDArray applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; /** * reduces dimensions in array relying on index operation OpName @@ -788,14 +799,14 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along * extraArgs - extra parameters for operation */ - void applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; + void applyIndexReduce(nd4j::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; /** * apply reduce3 operation OpName to this and other array, return result in new output array * other - input array * extraArgs - extra parameters for operation */ - NDArray* applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const ExtraArguments* extraParams = nullptr) const; + NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const; /** * apply reduce3 operation OpName to this and other array, return result in new output array @@ -803,7 +814,7 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along (tads not axis) * extraArgs - extra parameters for operation */ - NDArray* applyAllReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; + NDArray 
applyAllReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; /** * apply reduce3 (exec) operation OpName to this and other array, return result in new output array @@ -811,30 +822,26 @@ namespace nd4j { * dimensions - vector of dimensions to reduce along (same as reduceAlongDimension) * extraArgs - extra parameters for operation */ - NDArray* applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - + NDArray applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; /** * returns variance along given dimensions * biasCorrected - if true bias correction will be applied * dimensions - vector of dimensions to calculate variance along */ - NDArray* varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray* varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; + NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; + NDArray varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - NDArray varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - - void varianceAlongDimension(nd4j::variance::Ops op, NDArray* target, const bool biasCorrected, const std::vector& dimensions) const; - - void varianceAlongDimension(nd4j::variance::Ops op, NDArray* target, const bool biasCorrected, const std::initializer_list& dimensions) const; + void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, 
const bool biasCorrected, const std::vector& dimensions) const; + void varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list& dimensions) const; #endif /** * apply transpose operation to the copy of this array, that is this array remains unaffected */ - NDArray transpose() const; + NDArray transpose() const &; + NDArray transpose() &&; /** * perform transpose operation and store result in target, this array remains unaffected @@ -852,8 +859,8 @@ namespace nd4j { * index - the number of array to be returned among set of possible arrays * dimensions - array of dimensions to point on */ - NDArray* tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const; - NDArray* tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const; + NDArray tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const; + NDArray tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const; /** * returns the number of arrays pointing on specified dimension(s) @@ -874,54 +881,54 @@ namespace nd4j { * add given row vector to all rows of this array * row - row vector to add */ - void addiRowVector(const NDArray *row); + void addiRowVector(const NDArray& row); /** * add given row vector to all rows of this array, store result in target * row - row vector to add * target - where to store result */ - void addRowVector(const NDArray *row, NDArray* target) const; + void addRowVector(const NDArray& row, NDArray& target) const; /** * subtract given row vector from all rows of this array, store result in target * row - row vector to subtract * target - where to store result */ - void subRowVector(const NDArray *row, NDArray* target) const; + void subRowVector(const NDArray& row, NDArray& target) const; /** * multiply all rows of this array on given row vector, store result in target * row - row vector to multiply on * target - where to store result */ - void 
mulRowVector(const NDArray *row, NDArray* target) const; + void mulRowVector(const NDArray &row, NDArray& target) const; /** * divide all rows of this array on given row vector, store result in target * row - row vector to divide on * target - where to store result */ - void divRowVector(const NDArray *row, NDArray* target) const; + void divRowVector(const NDArray &row, NDArray& target) const; /** * add given column vector to all columns of this array, store result in target * column - column vector to add * target - where to store result */ - void addColumnVector(const NDArray *column, NDArray* target) const; + void addColumnVector(const NDArray &column, NDArray& target) const; /** * add given column vector to all columns of this array, this array becomes affected (in-place operation) * column - column vector to add */ - void addiColumnVector(const NDArray *column); + void addiColumnVector(const NDArray &column); /** * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) * column - column vector to multiply on */ - void muliColumnVector(const NDArray *column); + void muliColumnVector(const NDArray &column); /** * returns number of bytes used by _buffer & _shapeInfo @@ -958,7 +965,8 @@ namespace nd4j { * * if permute have been applied before or there are weird strides, then new buffer is allocated for new array */ - NDArray reshape(const char order, const std::vector& shape) const; + NDArray reshape(const char order, const std::vector& shape) const &; + NDArray reshape(const char order, const std::vector& shape) &&; /** * calculate strides and set given order @@ -991,12 +999,6 @@ namespace nd4j { */ void tile(NDArray& target) const; - /** - * returns an array which is result of broadcasting of this and other arrays - * other - input array - */ - NDArray* broadcast(const NDArray& other); - /** * check whether array is identity matrix */ @@ -1007,7 +1009,6 @@ namespace nd4j { */ bool isUnitary(); - /** * operator 
returns subarray with buffer pointing at this->_buffer with offset defined by given intervals * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) @@ -1038,27 +1039,6 @@ namespace nd4j { */ void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape = false) const; - /** - * addition operator: array + other - * other - input array to add - */ - NDArray operator+(const NDArray& other) const; - - /** - * addition operator: array + scalar - * scalar - input scalar to add - */ - template - NDArray operator+(const T& scalar) const; - - /** - * friend functions which implement addition operator: scalar + array - * scalar - input scalar to add - */ - //template - //friend NDArray nd4j::operator+(const T scalar, const NDArray& arr); - - /** * addition unary operator array += other * other - input array to add @@ -1077,42 +1057,11 @@ namespace nd4j { template void operator-=(const T other); - /** - * subtraction operator: array - other - * other - input array to subtract - */ - NDArray operator-(const NDArray& other) const; - - /** - * subtraction operator: array - scalar - * scalar - input scalar to subtract - */ - template - NDArray operator-(const T& scalar) const; - /** * negative operator, it changes sign of all array elements on opposite */ - NDArray operator-() const; - - /** - * friend functions which implement subtraction operator: scalar - array - * scalar - input scalar to subtract - */ - //friend NDArray nd4j::operator-(const float scalar, const NDArray& arr); - - /** - * pairwise multiplication operator: array * other - * other - input array to multiply on - */ - NDArray operator*(const NDArray& other) const; - - /** - * multiplication operator: array * scalar - * scalar - input scalar to multiply on - */ - template - NDArray operator*(const T& scalar) const; + NDArray 
operator-() const &; + NDArray operator-() &&; /** * pairwise multiplication unary operator array *= other @@ -1127,19 +1076,6 @@ namespace nd4j { template void operator*=(const T scalar); - /** - * pairwise division operator: array / other - * other - input array to divide on - */ - NDArray operator/(const NDArray& other) const; - - /** - * division operator: array / scalar - * scalar - input scalar to divide each array element on - */ - template - NDArray operator/(const T& scalar) const; - /** * pairwise division unary operator: array /= other * other - input array to divide on @@ -1180,7 +1116,7 @@ namespace nd4j { * return vector with buffer which points on corresponding diagonal elements of array * type - means of vector to be returned: column ('c') or row ('r') */ - NDArray* diagonal(const char type ) const; + NDArray diagonal(const char type ) const; /** * fill target matrix with given value in one or two directions from main diagonal: @@ -1194,7 +1130,7 @@ namespace nd4j { * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) */ template - void fillAsTriangular(const float value, int lower, int upper, const char direction = 'b', NDArray* target = nullptr); + void fillAsTriangular(const float value, int lower, int upper, NDArray& target, const char direction = 'b'); /** * change an array by repeating it the number of times in order to acquire new shape equal to the input shape @@ -1203,15 +1139,15 @@ namespace nd4j { * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place */ NDArray tileToShape(const Nd4jLong* shapeInfo); - void tileToShape(const std::vector& shape, NDArray* target = nullptr); + void tileToShape(const std::vector& shape, NDArray& target); #ifndef __JAVACPP_HACK__ - void tileToShape(const std::initializer_list& shape, NDArray* target = nullptr); + void tileToShape(const std::initializer_list& 
shape, NDArray& target); #endif template - NDArray* asT() const; + NDArray asT() const; - NDArray* asT(DataType dtype) const; + NDArray asT(DataType dtype) const; void linspace(const double start); @@ -1223,15 +1159,13 @@ namespace nd4j { */ double getTrace() const; - ResultSet* multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; + ResultSet multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; - ResultSet* allTensorsAlongDimension(const std::initializer_list& dimensions) const; + ResultSet allTensorsAlongDimension(const std::initializer_list& dimensions) const; - ResultSet* allTensorsAlongDimension(const std::vector& dimensions) const; + ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; - //ResultSet allTensorsAlongDims(const std::vector& dimensions) const; - - ResultSet* allExamples()const ; + ResultSet allExamples()const ; /** * set _shapeInfo @@ -1356,7 +1290,7 @@ namespace nd4j { /** * returns true if these two NDArrays have same rank, dimensions, strides, ews and order */ - FORCEINLINE bool isSameShapeStrict(const NDArray *other) const; + FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; /** * returns true if buffer && shapeInfo were defined (non nullptr) @@ -1439,11 +1373,6 @@ namespace nd4j { template void pIdx(const Nd4jLong* indices, const T value); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - NDArray* subarray(IndicesList& indices, std::vector& strides) const; - /** * returns true if array is 2D */ @@ -1512,64 +1441,9 @@ namespace nd4j { */ bool isS() const; - /** - * inline accessing operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i) const; - - /** - * inline modifying operator for matrix, i - absolute index - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i); - - /** - * inline accessing operator for 2D 
array, i - row, j - column - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j) const; - - /** - * inline modifying operator for 2D array, i - row, j - column - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j); - - /** - * inline accessing operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * inline modifying operator for 3D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - - /** - * inline modifying operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w); - - /** - * inline accessing operator for 4D array, i - height, j - width, k - depth - */ - //FORCEINLINE NDArray operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const; - - /** - * inline modifying operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray& operator()(const Nd4jLong* idx); - - /** - * inline accessing operator for ND array - * idx - array with corresponding indexes, for example {2,10,0,5,...,8}, number of indexes should be equal to array rank - */ - //FORCEINLINE NDArray operator()(const Nd4jLong* idx) const; - - - template std::vector asVectorT(); - FORCEINLINE bool isAttached(); NDArray* detach(); @@ -1585,394 +1459,201 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// - bool NDArray::isAttached() { - return this->_context->getWorkspace() != nullptr; - } +bool NDArray::isAttached() { + return this->_context->getWorkspace() != 
nullptr; +} - template - FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const { - auto b = reinterpret_cast(buffer); - auto v = static_cast(b[index]); - return v; - } +template +FORCEINLINE R NDArray::templatedGet(void *buffer, Nd4jLong index) const { + auto b = reinterpret_cast(buffer); + auto v = static_cast(b[index]); + return v; +} - ////////////////////////////////////////////////////////////////////////// - void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { + auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); - if (shapeInfo != nullptr) { - _dataType = ArrayOptions::dataType(_shapeInfo); - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = nd4j::DataType::INHERIT; + if (shapeInfo != nullptr) { + _dataType = ArrayOptions::dataType(_shapeInfo); + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; - } + else + _length = shape::length(_shapeInfo); } + else { + _dataType = nd4j::DataType::INHERIT; + _length = 0; + } +} - ////////////////////////////////////////////////////////////////////////// - void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const nd4j::DataType dtype) { + auto buffer = 
ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); - if (shapeInfo != nullptr) { - _dataType = dtype; - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = nd4j::DataType::INHERIT; + if (shapeInfo != nullptr) { + _dataType = dtype; + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; - } + else + _length = shape::length(_shapeInfo); } - - ////////////////////////////////////////////////////////////////////////// - char NDArray::ordering() const { - return shape::order(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isView() const { - return _isView; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong* NDArray::shapeOf() const { - return shape::shapeOf(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong* NDArray::stridesOf() const { - return shape::stride(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - int NDArray::rankOf() const { - return shape::rank(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::lengthOf() const { - return _length; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::rows() const { - if (this->rankOf() == 1) - return 1; - - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have rows"); - - return shapeOf()[0]; - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::columns() const { - if (this->rankOf() == 1) - return this->lengthOf(); - - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have columns"); - - return 
shapeOf()[1]; - } - - ////////////////////////////////////////////////////////////////////////// - - size_t NDArray::sizeOfT() const { - return DataTypeUtils::sizeOfElement(_dataType); - } - - ////////////////////////////////////////////////////////////////////////// - Nd4jLong NDArray::ews() const { - if (this->isEmpty() || this->rankOf() == 0) - return 1; - - return shape::elementWiseStride(_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::nonNull() const { - if (isEmpty()) - return true; - - if(!Environment::getInstance()->isCPU()) - return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr; - - return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr; - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isMatrix() const { - if (isEmpty()) - return false; - - return 0 != shape::isMatrix(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isVector() const { - if (isEmpty()) - return false; - if (rankOf() == 1) - return true; - return !isScalar() && shape::isVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isColumnVector() const { - if (isEmpty()) - return false; - - return !isScalar() && shape::isColumnVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isRowVector() const { - if (isEmpty()) - return false; - - // 1D edge case - if (shape::rank(this->_shapeInfo) == 1) - return true; - - return !isScalar() && shape::isRowVector(this->_shapeInfo); - } - - ////////////////////////////////////////////////////////////////////////// - bool NDArray::isCommonVector(int& posOfNonUnityDim) const { - - return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); - } - - 
////////////////////////////////////////////////////////////////////////// - bool NDArray::isScalar() const { - return 0 != shape::isScalar(this->_shapeInfo); + else { + _dataType = nd4j::DataType::INHERIT; + _length = 0; } +} ////////////////////////////////////////////////////////////////////////// -// accessing operator for matrix, i - absolute index -/* -NDArray NDArray::operator()(const Nd4jLong i) const { - - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - char order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } +char NDArray::ordering() const { + return shape::order(_shapeInfo); } -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for matrix, i - absolute index -/* -NDArray& NDArray::operator()(const Nd4jLong i) { - if (i >= shape::length(_shapeInfo)) - throw std::invalid_argument("NDArray::operator(i): input index is out of array length !"); - - auto ews = shape::elementWiseStride(_shapeInfo); - auto order = ordering(); - - if(ews == 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * 
this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME: bad - return result; - } else if(ews > 1 && order == 'c') { - auto cast = reinterpret_cast(_buffer) + (i * ews * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } else { - Nd4jLong idx[MAX_RANK]; - shape::ind2subC(rankOf(), shapeOf(), i, idx); - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; - } -}*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 2D matrix, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const { - - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // TODO: do we really want a view here? 
- auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +bool NDArray::isView() const { + return _isView; } -*/ -////////////////////////////////////////////////////////////////////////// -// modifying operator for 2D matrix, i - row, j - column -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); - - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -// accessing operator for 3D array, i - row, j - column -/* -NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || j >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +Nd4jLong* NDArray::shapeOf() const { + return shape::shapeOf(_shapeInfo); } -*/ ////////////////////////////////////////////////////////////////////////// -// modifying operator for 3D array -/* -NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { - - if (rankOf() != 3 || i >= 
shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - //FIXME: bad, will crash! - return result; +Nd4jLong* NDArray::stridesOf() const { + return shape::stride(_shapeInfo); } -*/ -/* -NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) const { - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ -/* -NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v, const Nd4jLong w) { - - if (rankOf() != 4 || t >= shapeOf()[0] || u >= shapeOf()[1] || v >= shapeOf()[2] || w >= shapeOf()[3]) - throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); - - Nd4jLong coords[4] = {t, u, v, w}; - auto xOffset = shape::getOffset(getShapeInfo(), coords); - - // FIXME - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; -} -*/ ////////////////////////////////////////////////////////////////////////// -/* -NDArray NDArray::operator()(const Nd4jLong* idx) 
const { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - return result; +int NDArray::rankOf() const { + return shape::rank(_shapeInfo); } -*/ + ////////////////////////////////////////////////////////////////////////// -/* -NDArray& NDArray::operator()(const Nd4jLong* idx) { - - for(int i = 0; i < rankOf(); ++i) - if (idx[i] >= sizeAt(i)) - throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); - - auto xOffset = shape::getOffset(getShapeInfo(), idx); - - auto cast = reinterpret_cast(_buffer) + (xOffset * this->sizeOfT()); - NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); - - // FIXME - return result; +Nd4jLong NDArray::lengthOf() const { + return _length; } -*/ +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::rows() const { + if (this->rankOf() == 1) + return 1; - ////////////////////////////////////////////////////////////////////////// - Nd4jLong FORCEINLINE NDArray::memoryFootprint() { - Nd4jLong size = this->lengthOf() * this->sizeOfT(); - size += shape::shapeInfoByteLength(this->rankOf()); - return size; - } + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have rows"); - ////////////////////////////////////////////////////////////////////////// - // still the definition of inline function must be in header file - bool NDArray::isSameShape(const std::vector& shape) const{ - if (this->isScalar() && shape.size() == 1 && shape[0] == 0) - return true; - if (this->rankOf() != (int) shape.size()) - return 
false; - for (int e = 0; e < this->rankOf(); e++) { - if (this->shapeOf()[e] != shape.at(e) && shape.at(e) != -1) - return false; - } + return shapeOf()[0]; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::columns() const { + if (this->rankOf() == 1) + return this->lengthOf(); + + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have columns"); + + return shapeOf()[1]; +} + +////////////////////////////////////////////////////////////////////////// + +size_t NDArray::sizeOfT() const { + return DataTypeUtils::sizeOfElement(_dataType); +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::ews() const { + if (this->isEmpty() || this->rankOf() == 0) + return 1; + + return shape::elementWiseStride(_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::nonNull() const { + if (isEmpty()) return true; + + if(!Environment::getInstance()->isCPU()) + return getDataBuffer()->special() != nullptr && getSpecialShapeInfo() != nullptr; + + return getDataBuffer()->primary() != nullptr && getShapeInfo() != nullptr; +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isMatrix() const { + if (isEmpty()) + return false; + + return 0 != shape::isMatrix(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isVector() const { + if (isEmpty()) + return false; + if (rankOf() == 1) + return true; + return !isScalar() && shape::isVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isColumnVector() const { + if (isEmpty()) + return false; + + return !isScalar() && shape::isColumnVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isRowVector() const { + if (isEmpty()) + return 
false; + + // 1D edge case + if (shape::rank(this->_shapeInfo) == 1) + return true; + + return !isScalar() && shape::isRowVector(this->_shapeInfo); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isCommonVector(int& posOfNonUnityDim) const { + + return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); +} + +////////////////////////////////////////////////////////////////////////// +bool NDArray::isScalar() const { + return 0 != shape::isScalar(this->_shapeInfo); +} + + +////////////////////////////////////////////////////////////////////////// +Nd4jLong FORCEINLINE NDArray::memoryFootprint() { + Nd4jLong size = this->lengthOf() * this->sizeOfT(); + size += shape::shapeInfoByteLength(this->rankOf()); + return size; +} + +////////////////////////////////////////////////////////////////////////// +// still the definition of inline function must be in header file +bool NDArray::isSameShape(const std::vector& shape) const{ + if (this->isScalar() && shape.size() == 1 && shape[0] == 0) + return true; + if (this->rankOf() != (int) shape.size()) + return false; + for (int e = 0; e < this->rankOf(); e++) { + if (this->shapeOf()[e] != shape.at(e) && shape.at(e) != -1) + return false; } + return true; +} ////////////////////////////////////////////////////////////////////////// bool NDArray::isSameShape(const NDArray *other) const { @@ -2009,8 +1690,8 @@ bool NDArray::areSameShapeAndType(const NDArray& other) const { // returns true if these two NDArrays have same _shapeInfo // still the definition of inline function must be in header file -bool NDArray::isSameShapeStrict(const NDArray *other) const { - return shape::equalsStrict(_shapeInfo, other->_shapeInfo); +bool NDArray::isSameShapeStrict(const NDArray& other) const { + return shape::equalsStrict(_shapeInfo, other._shapeInfo); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 
5adff5853..a08db4f16 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -35,21 +35,6 @@ ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; template <> ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; -////////////////////////////////////////////////////////////////////////// -template -NDArray* NDArray::asT() const{ - - auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); - auto l = this->lengthOf(); - - NDArray::prepareSpecialUse({result}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({result}, {this}); - - return result; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArray::asT, () const, LIBND4J_TYPES); - //////////////////////////////////////////////////////////////////////// // copy constructor NDArray::NDArray(const NDArray& other) { @@ -499,9 +484,7 @@ std::vector NDArray::asByteVector() { if (this->isView()) { auto tmp = this->dup(this->ordering()); syncToHost(); - memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); - - delete tmp; + memcpy(result.data(), tmp.getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); } else { syncToHost(); memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); @@ -590,26 +573,78 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); } +//////////////////////////////////////////////////////////////////// +// This method assigns values of given NDArray to this one +void 
NDArray::assign(const NDArray& other, bool allowParallelism) { + + if (this == &other) + return; + + if (other.isEmpty()) { + if (!isEmpty()) { + ArrayOptions::setPropertyBit(shapeInfo(), ARRAY_EMPTY); + syncShape(); + _buffer = std::make_shared(); + _offset = 0; + } + return; + } + + if(isEmpty()) { + *this = other; + return; + } + + if (other.lengthOf() == 1) { + + if(lengthOf() == 1) { + NDArray::preparePrimaryUse({this}, {&other}); + BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&other}); + this->syncToDevice(); + } + else { + if (dataType() != other.dataType()) { + auto tmp = other.cast(dataType()); + NDArray::prepareSpecialUse({this}, {&tmp}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.getBuffer(), tmp.getShapeInfo(), tmp.getSpecialBuffer(), tmp.getSpecialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {}); + } + else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } + } + } + else { + if (other.lengthOf() != lengthOf()) { + auto shapeThis = ShapeUtils::shapeAsString(this); + auto shapeThat = ShapeUtils::shapeAsString(&other); + nd4j_printf("Can't assign new value to the array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); + throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); + } + + // memcpy is allowed only for same order && same ews (being 
equal to 1) + if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) + copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); + else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } + } +} ////////////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one, wrt order - void NDArray::assign(const NDArray *other, bool allowParallelism) { - assign(*other, allowParallelism); - } - -////////////////////////////////////////////////////////////////////////// -// This method assigns given value to all elements in this NDArray -void NDArray::assign(const double value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {&temp}); +void NDArray::assign(const NDArray *other, bool allowParallelism) { + assign(*other, allowParallelism); } ////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const float value, bool allowParallelism) { +template +void NDArray::assign(const T& value, bool allowParallelism) { // just fire scalar auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); @@ -617,116 +652,19 @@ void 
NDArray::assign(const float value, bool allowParallelism) { NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); NDArray::registerSpecialUse({this}, {&temp}); } - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const float16 value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const bfloat16& value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const Nd4jLong value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); 
- NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int16_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint8_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), 
getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint16_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint32_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const uint64_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), 
getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const int8_t value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::assign(const bool value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} +template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); +template ND4J_EXPORT void 
NDArray::assign(const Nd4jLong& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); +template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); ////////////////////////////////////////////////////////////////////////// NDArray* NDArray::detach() { @@ -841,32 +779,7 @@ void* NDArray::bufferWithOffset(Nd4jLong offset) const { ////////////////////////////////////////////////////////////////////////// // eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, 
supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - return new NDArray(reduceAlongDims(op, dimensions, keepDims, supportOldShapes)); -} - -////////////////////////////////////////////////////////////////////////// -// eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDims(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -874,13 +787,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::FloatOps op, const std::vectorreduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -888,13 +801,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -902,13 +815,13 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray 
NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { std::vector copy(dimensions); @@ -916,29 +829,29 @@ NDArray NDArray::reduceAlongDims(nd4j::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray *NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { return 
reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); } @@ -1082,11 +995,6 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) cons return numTads; } -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const { - return tensorAlongDimension(index, std::vector(dimensions)); -} - ////////////////////////////////////////////////////////////////////////// void NDArray::printShapeInfo(const char * msg) const { //shape::printShapeInfo(_shapeInfo); @@ -1305,13 +1213,20 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() const { +NDArray NDArray::transpose() const &{ NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); newArr.transposei(); return newArr; } +////////////////////////////////////////////////////////////////////////// +// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected +NDArray NDArray::transpose() && { + + this->transposei(); + return std::move(*this); +} //////////////////////////////////////////////////////////////////////// // method performs transpose operation based on this array and store result in target, this array remains unaffected @@ -1418,7 +1333,7 @@ Nd4jLong NDArray::argMax(std::initializer_list dimensions) { ////////////////////////////////////////////////////////////////////////// // create new array with corresponding order and shape, new array will point to the same _buffer as this array -NDArray NDArray::reshape(const char order, const std::vector& shape) const { +NDArray NDArray::reshape(const char order, const std::vector& shape) 
const & { NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); newArr.reshapei(order, shape); @@ -1426,6 +1341,13 @@ NDArray NDArray::reshape(const char order, const std::vector& shape) c return newArr; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::reshape(const char order, const std::vector& shape) && { + + this->reshapei(order, shape); + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tilei(const std::vector& reps) { @@ -1490,7 +1412,7 @@ bool NDArray::permutei(const std::vector& dimensions) { } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) const { +NDArray NDArray::permute(const int* dimensions, const int rank) const & { // evaluate shapeInfo for output (permuted) array ret auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); @@ -1499,38 +1421,80 @@ NDArray NDArray::permute(const int* dimensions, const int rank) const { return ret; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const int* dimensions, const int rank) && { + + this->permutei(dimensions, rank); + return std::move(*this); +} + ///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const { +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ int tempDims[MAX_RANK]; shape::convertT(const_cast(dimensions), tempDims, rank); return permute(tempDims, rank); } +///////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { + + this->permutei(dimensions, rank); + return 
std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const { +NDArray NDArray::permute(const std::vector& dimensions) const &{ auto data = dimensions.data(); auto size = dimensions.size(); return permute(data, size); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const { +NDArray NDArray::permute(const std::vector& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) const & { return permute(dimensions.data(), dimensions.size()); } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::vector& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const { +NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ + std::vector vec(dimensions); return permute(vec); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const { +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) const & { std::vector vec(dimensions); return permute(vec); } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::permute(const std::initializer_list& dimensions) && { + + this->permutei(dimensions); + return std::move(*this); +} + 
////////////////////////////////////////////////////////////////////////// void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) @@ -1623,7 +1587,7 @@ T* NDArray::bufferAsT() const { BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(IndicesList& idx) const { +NDArray NDArray::subarray(IndicesList& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1655,11 +1619,11 @@ NDArray* NDArray::subarray(IndicesList& idx) const { indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride } } - return new NDArray((*this)(indexes, true, true)); + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(const std::initializer_list& idx) const { +NDArray NDArray::subarray(const std::initializer_list& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1698,11 +1662,11 @@ NDArray* NDArray::subarray(const std::initializer_list& idx) const { for (auto i: idx) delete i; - return new NDArray((*this)(indexes, true, true)); + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::subarray(const Intervals& idx) const { +NDArray NDArray::subarray(const Intervals& idx) const { const int idxSize = idx.size(); if (idxSize != this->rankOf()) @@ -1723,390 +1687,47 @@ NDArray* NDArray::subarray(const Intervals& idx) const { } } - return new NDArray((*this)(indexes, true)); + return NDArray((*this)(indexes, true)); } +////////////////////////////////////////////////////////////////////////// +template +NDArray NDArray::asT() const{ + + auto result = isScalar() ? 
NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); + auto l = this->lengthOf(); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); + + return result; +} +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); + //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::asT(DataType dtype) const { +NDArray NDArray::asT(DataType dtype) const { if (isS()) throw std::runtime_error("NDArray::asT: you can't use this method on String array!"); BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); - return nullptr; + return NDArray(); } //////////////////////////////////////////////////////////////////////// -template -NDArray* NDArray::cast() { - if (isS()) - throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); - return this->asT(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray* NDArray::cast(DataType dtype) const { +NDArray NDArray::cast(DataType dtype) const { if (isS()) throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); return this->asT(dtype); } //////////////////////////////////////////////////////////////////////// -void NDArray::cast(NDArray* target, DataType dtype) { +void NDArray::cast(NDArray& target, DataType dtype) { if (isS()) throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); // TODO: to be implemented properly - target->assign(this); -} - -//////////////////////////////////////////////////////////////////////// 
-// addition operator array + array -NDArray NDArray::operator+(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator+: you can't use this method on String array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (other.dataType() != DataType::BOOL) ) { - throw datatype_exception::build("NDArray::operator+: cannot add different types.", dataType(), other.dataType()); - } - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other); -} - -//////////////////////////////////////////////////////////////////////// -// addition operator array + scalar -template -NDArray NDArray::operator+(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator+: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), 
result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator+(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator+(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// subtraction operator array - scalar -template -NDArray NDArray::operator-(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator-(const double& scalar) const; 
-template ND4J_EXPORT NDArray NDArray::operator-(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator-(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*scalar -template -NDArray NDArray::operator*(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator*: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator*(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray 
NDArray::operator*(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator*(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// division operator array / scalar -template -NDArray NDArray::operator/(const T& scalar) const { - if (isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - - if(scalar == (T)0.) - throw std::runtime_error("NDArray::operator/ (division operator) : division by zero !"); - - auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); - NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &tmp}); - NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &tmp}); - - return result; -} -template ND4J_EXPORT NDArray NDArray::operator/(const double& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const float& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const float16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const bfloat16& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const Nd4jLong& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const int& scalar) const; -template ND4J_EXPORT NDArray 
NDArray::operator/(const int16_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const int8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const uint8_t& scalar) const; -template ND4J_EXPORT NDArray NDArray::operator/(const bool& scalar) const; - -//////////////////////////////////////////////////////////////////////// -// addition operator scalar + array -ND4J_EXPORT NDArray operator+(const float16& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const bfloat16& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const Nd4jLong& scalar, const NDArray& arr) { - return arr + scalar; -} -ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr) { - return arr + scalar; -} - -//////////////////////////////////////////////////////////////////////// -// addition operator scalar + array -ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr) { - return arr * scalar; -} - -ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const Nd4jLong& scalar, const NDArray& arr) { - return arr * scalar; -} -ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr) { - return arr * scalar; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const float16& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto 
tmp = NDArrayFactory::create(scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const bfloat16& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray 
result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const Nd4jLong& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), 
DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const bfloat16& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, 
arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const float16& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) 
- throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -ND4J_EXPORT NDArray operator/(const int& scalar, 
const NDArray & arr) { - if (arr.isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (arr.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide scalar by bool array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + target.assign(this); } //////////////////////////////////////////////////////////////////////// @@ -2133,11 +1754,11 @@ void NDArray::operator+=(const NDArray& other) { throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(getShapeInfo(), bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2167,11 +1788,11 @@ void NDArray::operator-=(const NDArray& other) { throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !"); 
if(shape::equalsTypesAndShapesSoft(getShapeInfo(), bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2200,11 +1821,11 @@ void NDArray::operator*=(const NDArray& other) { throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other, result, false); *this = std::move(result); // move assignment operator, zero cost copy } } @@ -2237,15 +1858,16 @@ void NDArray::operator/=(const NDArray& other) { throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !"); if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &other, this, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other, *this, false); } else { NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &other, &result, false); + this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other, result, false); *this = std::move(result); // move assignment operator, zero cost 
copy } } } + //////////////////////////////////////////////////////////////////////// template void NDArray::operator+=(const T value) { @@ -2335,77 +1957,9 @@ template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); template ND4J_EXPORT void NDArray::operator/=(const bool scalar); -//////////////////////////////////////////////////////////////////////// -// subtraction operator array - array -NDArray NDArray::operator-(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator-: you can't use this method on String array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw nd4j::datatype_exception::build("NDArray operator-: Cannot subtract different types", this->dataType(), other.dataType()); - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), other); -} - -//////////////////////////////////////////////////////////////////////// -// multiplication operator array*array -NDArray NDArray::operator*(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator*: you can't use this method on 
String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw nd4j::datatype_exception::build("NDArray operator*: Cannot multiply different types", this->dataType(), other.dataType()); - - PointersManager pointersManager(getContext(), "operator *"); - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this, &other}); - - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), other); -} - -//////////////////////////////////////////////////////////////////////// -// division operator array/array -NDArray NDArray::operator/(const NDArray& other) const { - if (isS()) - throw std::runtime_error("NDArray::operator/: you can't use this method on String array!"); - if (other.isB()) - throw std::runtime_error("NDArray::operator/: you can't divide by bool array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType()) - throw nd4j::datatype_exception::build("NDArray operator/: Cannot divide different types", this->dataType(), other.dataType()); - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); - - 
NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; - } - - return this->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), other); -} - //////////////////////////////////////////////////////////////////////// // negative operator, it makes all array elements = -elements -NDArray NDArray::operator-() const { +NDArray NDArray::operator-() const & { if (isS()) throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); @@ -2418,6 +1972,18 @@ NDArray NDArray::operator-() const { return result; } +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::operator-() && { + if (isS()) + throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + //////////////////////////////////////////////////////////////////////// // mathematical multiplication of two arrays NDArray mmul(const NDArray& left, const NDArray& right) { @@ -2430,9 +1996,9 @@ NDArray mmul(const NDArray& left, const NDArray& right) { } //////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::vector& shape, NDArray* target) { - if(target != nullptr) { - 
this->tile(*target); +void NDArray::tileToShape(const std::vector& shape, NDArray& target) { + if(&target != this) { + this->tile(target); return; } @@ -2457,7 +2023,7 @@ void NDArray::tileToShape(const std::vector& shape, NDArray* target) { } //////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::initializer_list& shape, NDArray* target) { +void NDArray::tileToShape(const std::initializer_list& shape, NDArray& target) { tileToShape(std::vector(shape), target); } @@ -2496,152 +2062,143 @@ double NDArray::getTrace() const { return sum; } - //////////////////////////////////////////////////////////////////////// -NDArray NDArray::quantize(NDArray &array) { - return *(quantize(&array)); -} +NDArray NDArray::quantize(const NDArray& array) { -//////////////////////////////////////////////////////////////////////// -NDArray* NDArray::quantize(NDArray *array) { - - if(array->isR()) + if(!array.isR()) throw std::invalid_argument("NDArray::quantize: type of array should be from real space!"); - auto ws = array->getContext()->getWorkspace(); + auto ws = array.getContext()->getWorkspace(); - Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array->getShapeInfo(), true, ws); + Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.getShapeInfo(), true, ws); ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); - std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array->lengthOf()), ArrayOptions::dataType(shapeInfo), ws); + std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws); - auto result = new NDArray(buffer, ShapeDescriptor(shapeInfo), array->getContext()); + NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* 
target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast method: target or other = nullptr !"); - if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB())) + + if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB())) throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - target->assign(this); - target->applyPairwiseTransform(op.p, *other, extraArgs); + target.assign(this); + target.applyPairwiseTransform(op.p, other, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsTypesAndShapesSoft(target->getShapeInfo(), newShapeInfo)) + if(!shape::equalsTypesAndShapesSoft(target.getShapeInfo(), 
newShapeInfo)) throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(dataType(), other->dataType(), target->dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(dataType(), other.dataType(), target.dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES); #endif } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + NDArray 
temp(target._shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != DataType::BOOL) + if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != DataType::BOOL) throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } - BUILD_DOUBLE_SELECTOR(dataType(), target->dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, BOOL_TYPES); } 
////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast int: you can't use this method on String array!"); - if(target == nullptr || other == nullptr) - throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !"); - if (isEmpty() || other->isEmpty()) + if (isEmpty() || other.isEmpty()) return; if (lengthOf() == 1) { - NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + NDArray temp(target._shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->lengthOf() == 1) { + if (other.lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType()) + if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); - if(dataType() != 
other->dataType()) + if(dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); } - if(target->isSameShape(this) || target->isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs); + if(target.isSameShape(this) || target.isSameShape(other)) { + const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); return; } - BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, *other, *target), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, other, target), INTEGER_TYPES); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & { if (isEmpty() || other.isEmpty()) { if (isEmpty()) return NDArray(*this); @@ -2654,19 +2211,100 @@ NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& o throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, &other, &result, false, extraArgs); + this->applyTrueBroadcast(op, other, result, false, extraArgs); return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & { + if 
(isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + if(!shape::shapeEquals(newShapeInfo, other.getShapeInfo())) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + if(!shape::shapeEquals(newShapeInfo, getShapeInfo())) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return 
NDArray(other); + } + + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + const bool thisMove = shape::shapeEquals(newShapeInfo, getShapeInfo()); + const bool otherMove = shape::shapeEquals(newShapeInfo, other.getShapeInfo()); + + if(!thisMove && !otherMove) { + + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } + + if(thisMove) { + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); + } + + // otherMove + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (isS()) throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!"); - if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other->isB()) || (op == broadcast::ReverseDivide && this->isB())) + if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB())) throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast method: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2674,28 +2312,26 @@ void 
NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& di if (dimensions.size() == 0) return; - auto result = target == nullptr ? this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other->getShapeInfo())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.getShapeInfo())) throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) 
throw std::invalid_argument("NDArray::applyBroadcast method: max and target arrays must have the same shape !"); std::vector copy(dimensions); @@ -2708,22 +2344,22 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector& di throw std::runtime_error("NDArray::applyBroadcast method: tad length mismatch !"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), 
packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (isS()) throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast BoolOps: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2731,30 +2367,28 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector if (dimensions.size() == 0) return; - auto result = target == nullptr ? 
this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != DataType::BOOL) + if(target.dataType() != DataType::BOOL) throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) throw std::invalid_argument("NDArray::applyBroadcast bool method: max and target arrays must have the same shape !"); - if(_dataType != other->_dataType) + if(_dataType != other._dataType) throw 
std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !"); std::vector copy(dimensions); @@ -2767,24 +2401,24 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector throw std::runtime_error("Tad length mismatch"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); // TODO: eventually we want separate tads here - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, 
copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { if (!isZ()) throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); - if(isEmpty() || other->isEmpty()) { - if(!target->isEmpty()) + if(isEmpty() || other.isEmpty()) { + if(!target.isEmpty()) throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); return; } @@ -2792,30 +2426,28 @@ void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& if (dimensions.empty()) return; - auto result = target == nullptr ? 
this : target; - - if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {this, other}); + if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); return; } NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { max = this; - min = const_cast(other); + min = const_cast(&other); } else { - max = const_cast(other); + max = const_cast(&other); min = this; } - if(result->dataType() != dataType()) + if(target.dataType() != dataType()) throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); - if(!result->isSameShape(max)) + if(!target.isSameShape(max)) throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !"); - if(_dataType != other->_dataType) + if(_dataType != other._dataType) throw 
std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); std::vector copy(dimensions); @@ -2828,76 +2460,23 @@ void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& throw std::runtime_error("Tad length mismatch"); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy); // TODO: eventually we want separate tads here - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); if(max == this) - NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), 
packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - registerSpecialUse({result}, {this, other}); + NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) { +void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) { std::vector vec(dimensions); applyBroadcast(op, vec, tadArray, target, extraArgs); } -////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, ExtraArguments *extraArgs) const { - return new NDArray(this->applyTrueBroadcast(op, *other, extraArgs)); -} - -////////////////////////////////////////////////////////////////////////// -// return array which is broadcasted from this and argument array -NDArray* NDArray::broadcast(const NDArray& other) { - // the orders must be the same - char order = ordering(); - if(order != other.ordering()) - throw std::runtime_error("NDArray::broadcast method: arrays have different orders!"); - - // recognize shapes with smaller and bigger rank - Nd4jLong* biggerShapeInfo = nullptr; - Nd4jLong* smallerShapeInfo = nullptr; - int smallerRank, biggerRank; - if (rankOf() > other.rankOf()) { - 
biggerShapeInfo = _shapeInfo; - biggerRank = shape::rank(_shapeInfo); - smallerShapeInfo = other._shapeInfo; - smallerRank = shape::rank(other._shapeInfo); - } - else { - biggerShapeInfo = other._shapeInfo; - biggerRank = shape::rank(other._shapeInfo); - smallerShapeInfo = _shapeInfo; - smallerRank = shape::rank(_shapeInfo); - } - - // check shapes on consistency - int diff = biggerRank - smallerRank; - for (int i = smallerRank; i<=1; --i) - if(biggerShapeInfo[diff+i] != smallerShapeInfo[i] && biggerShapeInfo[i] != 1 && smallerShapeInfo[i] != 1) - throw std::runtime_error("Broadcast method: arrays have incompatible shapes !"); - - // create and fill ret shapeInfo - Nd4jLong *shapeInfoNew; - ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(biggerRank), Nd4jLong); - memcpy(shapeInfoNew, biggerShapeInfo, shape::shapeInfoByteLength(biggerRank)); - for (int i = smallerRank; i>=1; --i) - if(shapeInfoNew[diff+i] == 1 || smallerShapeInfo[i] == 1) - shapeInfoNew[diff+i] *= smallerShapeInfo[i]; - - ShapeUtils::updateStridesAndType(shapeInfoNew, DataTypeUtils::pickPairwiseResultType(dataType(), other.dataType()), order); - - auto ret = new NDArray(shapeInfoNew, true, getContext()); - - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - - return ret; -} - //////////////////////////////////////////////////////////////////////// void* NDArray::operator new(size_t i) { if (nd4j::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) { @@ -3017,7 +2596,7 @@ bool NDArray::reshapei(const char order, const std::vector& cshape) { } else { NDArray temp(order, shape, dataType(), getContext()); - this->applyTransform(transform::Assign, &temp, nullptr); + this->applyTransform(transform::Assign, temp, nullptr); *this = std::move(temp); } @@ -3049,57 +2628,57 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, nd4j::DataType BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, 
nd4j::DataType dtype, const void *value), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* other, NDArray *target, ExtraArguments *extraParams) const{ +void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ if (isS()) throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) + if (other.lengthOf() != target.lengthOf()) throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched"); - if (target->dataType() != this->dataType() && target->dataType() != other->dataType()) + if (target.dataType() != this->dataType() && target.dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); if (extraParams != nullptr) synchronize("NDArray::applyPairwiseTransform"); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ +void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ if (isS()) throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) + if (other.lengthOf() != target.lengthOf()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched"); - if (!target->isB()) + if (!target.isB()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type"); - if (dataType() != other->dataType()) + if (dataType() != other.dataType()) throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); } //////////////////////////////////////////////////////////////////////// - void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); - if (other->lengthOf() != target->lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); - if (!target->isZ()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); - if (dataType() != other->dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); +void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ + if (isS()) + throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); + if (!target.isZ()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have integer type"); + 
if (dataType() != other.dataType()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); - NDArray::prepareSpecialUse({target}, {this, other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); - NDArray::registerSpecialUse({target}, {this, other}); - } + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); +} ////////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { - applyPairwiseTransform(op, &other, this, extraParams); + applyPairwiseTransform(op, other, *this, extraParams); } //////////////////////////////////////////////////////////////////////// @@ -3112,41 +2691,31 @@ void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::vector& dimensions) const { +void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { if (isS()) throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); - if (!target->isR()) + if (!target.isR()) throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type"); - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == dimensions.size() || dimensions.empty()) - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), biasCorrected); + NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, 
target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), biasCorrected); else { std::vector copy(dimensions); auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->buffer(), target->shapeInfo(), target->getSpecialBuffer(), target->specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); + NativeOpExecutioner::execSummaryStats(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.getSpecialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); synchronize("NDArray::varianceAlongDimension"); } - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { - return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::initializer_list& dimensions) const { - varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { +NDArray 
NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { if (isS()) throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); @@ -3157,87 +2726,27 @@ NDArray NDArray::varianceAlongDims(nd4j::variance::Ops op, const bool biasCorrec auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); NDArray result(newShape, true, getContext()); - this->varianceAlongDimension(op, &result, biasCorrected, dimensions); + this->varianceAlongDimension(op, result, biasCorrected, dimensions); return result; } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { - - return new NDArray(this->varianceAlongDims(op, biasCorrected, dimensions)); +NDArray NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { + return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); } -//////////////////////////////////////////////////////////////////// -// This method assigns values of given NDArray to this one -void NDArray::assign(const NDArray& other, bool allowParallelism) { - - if (this == &other) - return; - - if (other.isEmpty()) { - if (!isEmpty()) { - ArrayOptions::setPropertyBit(shapeInfo(), ARRAY_EMPTY); - syncShape(); - _buffer = std::make_shared(); - _offset = 0; - } - return; - } - - if(isEmpty()) { - *this = other; - return; - } - - if (other.lengthOf() == 1) { - - if(lengthOf() == 1) { - NDArray::preparePrimaryUse({this}, {&other}); - BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&other}); - 
this->syncToDevice(); - } - else { - if (dataType() != other.dataType()) { - auto tmp = other.cast(dataType()); - NDArray::prepareSpecialUse({this}, {tmp}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {}); - delete tmp; - } - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } - } - else { - if (other.lengthOf() != lengthOf()) { - auto shapeThis = ShapeUtils::shapeAsString(this); - auto shapeThat = ShapeUtils::shapeAsString(&other); - nd4j_printf("Can't assign new value to the array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); - throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); - } - - // memcpy is allowed only for same order && same ews (being equal to 1) - if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) - copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } 
+//////////////////////////////////////////////////////////////////////// +void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list& dimensions) const { + varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// // This method returns new copy of this NDArray, optionally in different order -NDArray* NDArray::dup(const char newOrder) const { +NDArray NDArray::dup(const char newOrder) const { if (isEmpty()) - return NDArrayFactory::empty_(dataType(), getContext()); + return NDArrayFactory::empty(dataType(), getContext()); char order = newOrder == 'a' ? ordering() : newOrder; @@ -3248,12 +2757,12 @@ NDArray* NDArray::dup(const char newOrder) const { for (int e = 0; e < lengthOf(); e++) strings[e] = this->e(e); - auto result = NDArrayFactory::string_(order, getShapeAsVector(), strings, getContext()); + auto result = NDArrayFactory::string(order, getShapeAsVector(), strings, getContext()); return result; } - auto result = new NDArray(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); - result->assign(*this); + NDArray result(order, isScalar() ? 
std::vector({0}) : getShapeAsVector(), dataType(), getContext()); + result.assign(*this); return result; } @@ -3432,87 +2941,72 @@ NDArray NDArray::e(const Nd4jLong i) const { ////////////////////////////////////////////////////////////////////////// // perform array transformation -void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!target->isR()) + if (!target.isR()) throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (target->dataType() != dataType()) + if (target.dataType() != dataType()) throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!this->isR() || !target->isR() || (this->dataType() != target->dataType())) + if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, ExtraArguments *extraParams) { +void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!"); - if (target == nullptr) - target = this; - - if (!target->isB()) + if (!target.isB()) throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types"); - NDArray::prepareSpecialUse({target}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!"); @@ -3526,7 +3020,19 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); @@ -3540,7 +3046,19 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use 
this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) const & { if (!this->isR()) throw std::runtime_error("Source array must have one of FLOAT types"); @@ -3554,7 +3072,19 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const { +NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) && { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const & { if (isS()) throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); @@ -3567,151 +3097,159 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const return result; } +//////////////////////////////////////////////////////////////////////// +NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) && { + if (isS()) + throw 
std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) { +void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (scalar->lengthOf() != 1) + if (scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); - if(target == nullptr) - target = this; - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType())) + + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.getShapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray *target, ExtraArguments *extraParams) { - - auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments 
*extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams); - ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { +void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); - if (target == nullptr || !target->isB()) - throw std::invalid_argument("NDArray::applyScalarArr bool method: target is nullptr or has not bool type!"); - if (dataType() != scalar->dataType()) { - nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); + if (!target.isB()) + throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); } - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), 
target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); + + if (target.dataType() != this->dataType()) + throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); + throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.getBuffer(), scalar.getShapeInfo(), scalar.getSpecialBuffer(), scalar.getSpecialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } //////////////////////////////////////////////////////////////////////// template -void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const { +void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { - NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); + NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); } -template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT 
void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; - - -////////////////////////////////////////////////////////////////////////// - void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); - - if (target == nullptr || target->dataType() != this->dataType()) - throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has not bool type!"); - if (dataType() != scalar->dataType()) { - nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); - throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({target}, {this, scalar}); - NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target->dataType()): nullptr); - NDArray::registerSpecialUse({target}, {this, scalar}); - } +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; //////////////////////////////////////////////////////////////////////// - template - void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, 
ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); - applyScalarArr(op, &scalarArr, target, extraParams); - } - - template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; +template +void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray& target, 
ExtraArguments *extraParams) { + auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); //////////////////////////////////////////////////////////////////////// -void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams) const { 
+template +void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { + + NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} + +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; + 
+//////////////////////////////////////////////////////////////////////// +void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - if (target->dataType() != nd4j::DataType::INT64 && target->dataType() != nd4j::DataType::INT32) + if (target.dataType() != nd4j::DataType::INT64 && target.dataType() != nd4j::DataType::INT32) throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); - if (target->lengthOf() == 1) { - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo()); + if (target.lengthOf() == 1) { + NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); } else { std::vector copy = dimensions; shape::checkDimensions(rankOf(), copy); auto pDims = nd4j::Environment::getInstance()->isCPU() ? 
copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - NativeOpExecutioner::execIndexReduce(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execIndexReduce(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); synchronize("NDArray::applyIndexReduce"); } - registerSpecialUse({target}, {this}); + registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// // reduce dimensions in this array relying on index operations -NDArray* NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { +NDArray NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { std::vector copy = dimensions; auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace()); - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); applyIndexReduce(op, result, copy, extraParams); @@ -3720,10 +3258,11 @@ NDArray* NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, const std::vector< //////////////////////////////////////////////////////////////////////// // apply reduce3 operations to this and other array, return result in new output array -NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const ExtraArguments* extraParams) const { +NDArray 
NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const { + if (isS()) throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); // check shapes consistency if(!isSameShape(other)) @@ -3731,75 +3270,75 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons // create shapeInfo for scalar auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace()); // create output array (scalar) - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); RELEASE(newShape, getContext()->getWorkspace()); // create dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? 
const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo()); - NDArray::registerSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this, &other}); return result; } //////////////////////////////////////////////////////////////////////// // apply reduce3 (exec) operations to this and other array, return result in new output array -NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, const std::vector& dimensions, const ExtraArguments* extraParams) const { +NDArray NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other->rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), 
false, false, getContext()->getWorkspace()); - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); // perform calculations - if(rankOf() == copy.size() && other->rankOf() == copy.size()) { - NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo()); + if(rankOf() == copy.size() && other.rankOf() == copy.size()) { + NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? 
copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(other->getShapeInfo(), copy); + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(other.getShapeInfo(), copy); if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); - NativeOpExecutioner::execReduce3(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NativeOpExecutioner::execReduce3(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); } - registerSpecialUse({result}, {this, other}); + registerSpecialUse({&result}, {this, &other}); return result; } //////////////////////////////////////////////////////////////////////// // apply reduce3 (execAll) operations to this and other array, return result in new output array -NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, const std::vector& dimensions, const ExtraArguments* extraParams) const { +NDArray NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray& 
other, const std::vector& dimensions, const ExtraArguments* extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); - if(dataType() != other->dataType()) + if(dataType() != other.dataType()) throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other->rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other->getShapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other.getShapeInfo(), copy); // check tads shapes if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) @@ -3809,145 +3348,145 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); // create output array - auto result = new NDArray(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); // create dynamic array of extra parameters if array extraParams is empty (==nullptr) void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; auto pDims = nd4j::Environment::getInstance()->isCPU() ? 
copy.data() : nullptr; - NDArray::prepareSpecialUse({result}, {this, other}); - NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({result}, {this, other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NDArray::registerSpecialUse({&result}, {this, &other}); return result; } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); - if (target == nullptr || !target->isR()) + if (!target.isR()) throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target 
array to be present and have type form real space!"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(),nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(),nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(getShapeInfo(), copy); - NativeOpExecutioner::execReduceFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } 
synchronize("NDArray::reduceAlongDimension FloatOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); - if (target == nullptr || target->dataType() != dataType()) + if (target.dataType() != dataType()) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype as input"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), 
target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { //if (!isEmpty()) { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension SameOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); - if (target == nullptr || target->dataType() != 
DataType::INT64) + if (target.dataType() != DataType::INT64) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? 
copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceLong(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceLong(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension LongOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray* target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { +void NDArray::reduceAlongDimension(nd4j::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { if (isS()) throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); - if (target == nullptr || !target->isB()) + if (!target.isB()) throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); std::vector copy(dimensions); if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target->ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - 
if(!shape::shapeEquals(newShape, target->getShapeInfo())) + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.getShapeInfo())) throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); } - NDArray::prepareSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); + NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo()); } else { auto pDims = nd4j::Environment::getInstance()->isCPU() ? 
copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - NativeOpExecutioner::execReduceBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + NativeOpExecutioner::execReduceBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); } synchronize("NDArray::reduceAlongDimension LongOps"); - NDArray::registerSpecialUse({target}, {this}); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// @@ -4102,152 +3641,152 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { } ////////////////////////////////////////////////////////////////////////// -void NDArray::addRowVector(const NDArray *row, NDArray *target) const { +void NDArray::addRowVector(const NDArray& row, NDArray& target) const { if (isS()) throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType()) && !(isR() && row->isR() && target->isR())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), 
row.dataType()) && !(isR() && row.isR() && target.isR())) throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::subRowVector(const NDArray *row, NDArray *target) const { +void NDArray::subRowVector(const NDArray& row, NDArray& target) const { if (isS()) throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) throw 
std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType()) && !(isR() && row->isR() && target->isR())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::mulRowVector(const NDArray *row, NDArray *target) const { +void NDArray::mulRowVector(const NDArray &row, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() 
!= 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->columns()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::divRowVector(const NDArray *row, NDArray *target) const { +void 
NDArray::divRowVector(const NDArray &row, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); - if (row->isB()) + if (row.isB()) throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !row->isRowVector() || columns() != row->columns()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, row}); + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), target.getSpecialBuffer(), 
target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// // This method adds given row to all rows in this NDArray, this array becomes affected -void NDArray::addiRowVector(const NDArray *row) { +void NDArray::addiRowVector(const NDArray& row) { if (isS()) throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || !row->isRowVector() || columns() != row->lengthOf()) + if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); int dimension = 1; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {row}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row->getBuffer(), row->getShapeInfo(), row->getSpecialBuffer(), row->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {row}); + NDArray::prepareSpecialUse({this}, {&row}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), row.getBuffer(), row.getShapeInfo(), row.getSpecialBuffer(), row.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&row}); } ////////////////////////////////////////////////////////////////////////// -void 
NDArray::addColumnVector(const NDArray *column, NDArray *target) const { +void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { if (isS()) throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || target->rankOf() != 2 || rows() != target->rows() || columns() != target->columns() || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); - if(target->dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column->dataType())) + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({target}, {this, column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({target}, {this, column}); + NDArray::prepareSpecialUse({&target}, {this, &column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), target.getBuffer(), target.getShapeInfo(), 
target.getSpecialBuffer(), target.getSpecialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &column}); } ////////////////////////////////////////////////////////////////////////// // This method adds given column to all columns in this NDArray, this array becomes affected -void NDArray::addiColumnVector(const NDArray *column) { +void NDArray::addiColumnVector(const NDArray &column) { if (isS()) throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {column}); + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Add, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); } 
////////////////////////////////////////////////////////////////////////// // This method multiplies each column of this array by given argument-column, this array becomes affected -void NDArray::muliColumnVector(const NDArray *column) { +void NDArray::muliColumnVector(const NDArray& column) { if (isS()) throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column->isColumnVector() || rows() != column->lengthOf()) + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); int dimension = 0; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {column}); - NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column->getBuffer(), column->getShapeInfo(), column->getSpecialBuffer(), column->getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {column}); + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), nd4j::broadcast::Ops::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), column.getBuffer(), column.getShapeInfo(), column.getSpecialBuffer(), column.getSpecialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); } ////////////////////////////////////////////////////////////////////////// @@ -4278,8 +3817,8 @@ bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { } 
//////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { - auto result = new ResultSet(); +ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { + ResultSet result; if (indices.size() == 0) return result; @@ -4296,19 +3835,19 @@ ResultSet* NDArray::multipleTensorsAlongDimension(const std::vector &indice } auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); - result->push_back(array); + result.push_back(array); } return result; } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { +ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { return allTensorsAlongDimension(std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allExamples() const { +ResultSet NDArray::allExamples() const { std::vector dimensions(rankOf() - 1); for (int e = 1; e < rankOf(); e++) dimensions[e-1] = e; @@ -4338,7 +3877,7 @@ NDArray NDArray::ulike() { } //////////////////////////////////////////////////////////////////////// -NDArray* NDArray::diagonal(const char type) const { +NDArray NDArray::diagonal(const char type) const { if (isS()) throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); @@ -4386,7 +3925,7 @@ NDArray* NDArray::diagonal(const char type) const { ArrayOptions::setDataType(outShapeInfo, this->dataType()); - auto result = new NDArray(_buffer, ShapeDescriptor(outShapeInfo), getContext(), getBufferOffset()); + NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), getBufferOffset()); RELEASE(outShapeInfo, 
getContext()->getWorkspace()); @@ -4394,9 +3933,9 @@ NDArray* NDArray::diagonal(const char type) const { } //////////////////////////////////////////////////////////////////////// -ResultSet* NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { +ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { - auto result = new ResultSet(); + ResultSet result; if(dimensions.size() == 0) return result; @@ -4411,14 +3950,14 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector &dimensions) for (int idx = 0; idx < numTads; idx++ ) { auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); array->_isView = true; - result->push_back(array); + result.push_back(array); } return result; } ////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const { +NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& dimensions) const { std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); @@ -4430,12 +3969,17 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& d auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); - auto array = new NDArray(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset()); - array->_isView = true; + NDArray array(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset()); + array._isView = true; return array; } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list& dimensions) const { + return tensorAlongDimension(index, std::vector(dimensions)); +} + 
//////////////////////////////////////////////////////////////////////// // operator returns sub-array with buffer pointing at this->_buffer + certain offset NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { @@ -4606,6 +4150,539 @@ void NDArray::setShapeInfo(const ConstantDataBuffer& shapeBuffer) { _dataType = ArrayOptions::dataType(_shapeInfo); } +/////////////////////////////////////////////////////////////////////// +// addition operator array + scalar +template +NDArray operator+(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr + scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Add, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT 
NDArray operator+(NDArray&& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Add, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, NDArray&& arr) { + return std::move(arr) + scalar; +} +template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator+(const int& scalar, 
NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator+(const T& scalar, const NDArray& arr) { + return arr + scalar; +} +template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition operator array - scalar +template +NDArray operator-(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr - scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Subtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const NDArray& arr, const T& scalar) { + + if 
(arr.isS()) + throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Subtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, NDArray&& arr) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar - arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), 
arr.getSpecialShapeInfo(), arr.getBuffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); + +} +template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator-(const T& scalar, const NDArray& arr) { + + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseSubtract, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +// addition 
operator array + scalar +template +NDArray operator*(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr * scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): the type of the resulting array must be the same as the type of the input array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Multiply, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); +template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
+ NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Multiply, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} + +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); +template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, NDArray&& arr) { + return std::move(arr) * scalar; +} +template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator*(const T& scalar, const NDArray& arr) { + return arr * scalar; +} +template ND4J_EXPORT NDArray operator*(const double& 
scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); + +/////////////////////////////////////////////////////////////////////// +template +NDArray operator/(NDArray&& arr, const T& scalar) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr / scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): the type of the resulting array must be the same as the type of the input array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Divide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.buffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); +} +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); +template ND4J_EXPORT
NDArray operator/(NDArray&& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const NDArray& arr, const T& scalar) { + + if (arr.isS()) + throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::Divide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); +template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, NDArray&& arr) { + + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar / arr); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + + auto 
tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), arr.getBuffer(), arr.getShapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); + +} +template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); +template ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); + + +//////////////////////////////////////////////////////////////////////// +template +NDArray operator/(const T& scalar, const NDArray& arr) { + + if (arr.isS()) + throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), nd4j::scalar::ReverseDivide, arr.getBuffer(), arr.getShapeInfo(), arr.getSpecialBuffer(), arr.getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; +} +template ND4J_EXPORT NDArray operator/(const double& scalar, const 
NDArray& arr); +template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); +template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); + +//////////////////////////////////////////////////////////////////////// +// addition operator array + array +template +NDArray operator+(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot add different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Add, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return
std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// subtraction operator array - array +template +NDArray operator-(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot subtract different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(),
DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Subtract, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// multiplication operator array*array +template +NDArray operator*(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || 
arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Multiply, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); +template 
ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); + +//////////////////////////////////////////////////////////////////////// +// division operator array/array +template +NDArray operator/(T1&& arr1, T2&& arr2) { + + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw nd4j::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot divide different types", arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.getShapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.getShapeInfo(), arr2.getShapeInfo()), false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), nd4j::pairwise::Divide, arr1.getBuffer(), arr1.getShapeInfo(), arr1.getSpecialBuffer(), arr1.getSpecialShapeInfo(), arr2.getBuffer(), arr2.getShapeInfo(), arr2.getSpecialBuffer(), arr2.getSpecialShapeInfo(), result->buffer(), result->getShapeInfo(), result->specialBuffer(), result->getSpecialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res =
std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), std::forward(arr2)); +} +template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); +template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); +template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); + /* #ifndef __CLION_IDE__ diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index ef45a3e0c..2190afbf1 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -104,7 +104,7 @@ namespace graph { if (node->id() == 13) nd4j_debug("",""); - // if true - this is special case: Graph-in-Graph. + // if true - this is special case: Graph-in-Graph. 
if (node->hasGraphEmbedded()) { auto embedded = node->getGraph(); @@ -128,12 +128,12 @@ namespace graph { int cnt = 0; for (Variable* v: *embedded->getPlaceholders()) { if (v->getName() != nullptr && v->getName()->size() > 0) { - + // trying symbolic lookup first if (variableSpace->hasVariable(v->getName())) { // symbolic feeder auto array = variableSpace->getVariable(v->getName())->getNDArray(); - auto vr = array->dup(); + auto vr = new NDArray(array->dup()); // deletables.push_back(vr); v->setNDArray(vr); } else { @@ -145,7 +145,7 @@ namespace graph { // if we're not using symbolic lookup - we'll use sequential approach then auto p = node->input()->at(cnt); auto array = variableSpace->getVariable(p)->getNDArray(); - auto vr = array->dup(); + auto vr = new NDArray(array->dup()); //deletables.push_back(vr); v->setNDArray(vr); } @@ -501,7 +501,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) } /** - * This method is provided for IPC: + * This method is provided for IPC: * 1) it accepts pointer to FlatBuffers buffer * 2) restores Graph from it * 3) Executes this Graph diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index dc9d09231..9dd2ed967 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -71,44 +71,41 @@ void NDArray::makeBothBuffersActual() const { } //////////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) { +void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { if (isS()) throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!"); - if(target == nullptr) - target = this; - - if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1))) + if(!isSameShape(target) && 
!(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !"); if (direction == 'u') - lower = -target->sizeAt(-2); + lower = -target.sizeAt(-2); else if (direction == 'l') - upper = target->sizeAt(-1); + upper = target.sizeAt(-1); const T value = static_cast(val); const auto x = reinterpret_cast(getBuffer()); - auto z = reinterpret_cast(target->getBuffer()); + auto z = reinterpret_cast(target.getBuffer()); const int xRank = rankOf(); - const int zRank = target->rankOf(); + const int zRank = target.rankOf(); - const auto zLen = target->lengthOf(); + const auto zLen = target.lengthOf(); - const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo()); + const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo()); auto func = PRAGMA_THREADS_FOR { Nd4jLong coords[MAX_RANK]; for (auto i = start; i < stop; i += increment) { - shape::index2coords(i, target->getShapeInfo(), coords); - const auto zOffset = shape::getOffset(target->getShapeInfo(), coords); + shape::index2coords(i, target.getShapeInfo(), coords); + const auto zOffset = shape::getOffset(target.getShapeInfo(), coords); // if( (row + upper < col) || (row + lower > col) ) if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) z[zOffset] = value; - else if (this != target) { // when this and target are different arrays + else if (this != &target) { // when this and target are different arrays if (xRank != zRank) coords[0] = coords[1]; @@ -120,7 +117,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char samediff::Threads::parallel_for(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); 
+BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { @@ -405,11 +402,11 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector& repeats) const { +NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { - auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, *output, repeats, axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES); return output; } diff --git a/libnd4j/blas/cpu/NDArrayLambda.hpp b/libnd4j/blas/cpu/NDArrayLambda.hpp index 6ce8e6823..86d798efc 100644 --- a/libnd4j/blas/cpu/NDArrayLambda.hpp +++ b/libnd4j/blas/cpu/NDArrayLambda.hpp @@ -2,35 +2,24 @@ template -void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target) { - if (second == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n",""); - throw std::runtime_error("second is null"); - } - - if (third == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n",""); - throw std::runtime_error("third is null"); - } if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != second->dataType() || dataType() 
!= third->dataType() || dataType() != target->dataType()) + if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType()) throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); - if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { + if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = second->bufferAsT(); - auto t = third->bufferAsT(); - auto z = target->bufferAsT(); + auto s = second.bufferAsT(); + auto t = third.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) { + if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -44,8 +33,8 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std:: auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } @@ -57,9 +46,9 @@ void NDArray::applyTriplewiseLambda(NDArray* second, 
NDArray *third, const std:: auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); - auto zOffset = target->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } @@ -69,46 +58,39 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std:: } } } -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const 
std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = 
this; - - if (other == nullptr) { - nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } +void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != other->dataType() || dataType() != target->dataType()) + if(dataType() != other.dataType() || dataType() != target.dataType()) throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); - if (this->lengthOf() != other->lengthOf()) { + if (this->lengthOf() != other.lengthOf()) { nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -122,7 +104,7 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); + auto yOffset = other.getOffset(e); f[xOffset] = func(f[xOffset], s[yOffset]); } @@ -134,8 +116,8 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset 
= target.getOffset(e); z[zOffset] = func(f[xOffset], s[yOffset]); } @@ -145,35 +127,33 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void 
NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyLambda(const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); auto f = this->bufferAsT(); - auto z = target->bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -198,7 +178,7 @@ void 
NDArray::applyLambda(const std::function& func, NDArray* target) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(f[xOffset]); } @@ -208,35 +188,33 @@ void NDArray::applyLambda(const std::function& func, NDArray* target) { } } } -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, 
NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyIndexedLambda(const std::function& func, NDArray& target) { if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); auto f = this->bufferAsT(); - auto z = target->bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -261,7 +239,7 @@ void NDArray::applyIndexedLambda(const std::function& func, NDAr auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func(e, f[xOffset]); } @@ -271,44 +249,38 @@ void NDArray::applyIndexedLambda(const std::function& func, NDAr } } } -template void NDArray::applyIndexedLambda(const std::function& func, 
NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void 
NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target) { - if (other == nullptr) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } if(dataType() != DataTypeUtils::fromT()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) + if(dataType() != target.dataType()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); - if (this->lengthOf() != other->lengthOf()) { + if (this->lengthOf() != other.lengthOf()) { nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { auto loop = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) @@ -322,7 +294,7 @@ void 
NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); + auto yOffset = other.getOffset(e); f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); } @@ -334,8 +306,8 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::functiongetOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); } @@ -345,16 +317,16 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* 
target); \ No newline at end of file +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); \ No newline at end of file diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index e790c05d0..c01343818 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2717,25 +2717,25 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub switch (opCode) { case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); break; 
case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); break; case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); break; case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); break; case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); break; case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); break; case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); break; default: continue; diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index be90a22ae..48c7a7933 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -122,35 +122,32 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha /////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) { +void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { if (isS()) throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!"); - if(target == nullptr) - target = this; - - if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1))) + 
if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !"); if (direction == 'u') - lower = -target->sizeAt(-2); + lower = -target.sizeAt(-2); else if (direction == 'l') - upper = target->sizeAt(-1); + upper = target.sizeAt(-1); const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (target->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(decltype(*target->getShapeInfo())) * target->rankOf() + 128; + const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(decltype(*target.getShapeInfo())) * target.rankOf() + 128; PointersManager manager(getContext(), "NDArray::fillAsTriangular"); - NDArray::prepareSpecialUse({target}, {this}); - fillAsTriangularCuda<<getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target->getPlatformBuffer(), target->getPlatformShapeInfo(), static_cast(val), lower, upper); - NDArray::registerSpecialUse({target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + fillAsTriangularCuda<<getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast(val), lower, upper); + NDArray::registerSpecialUse({&target}, {this}); manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template @@ -457,21 +454,21 @@ BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const 
int blocksPerGrid ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats -NDArray* NDArray::repeat(const int axis, const std::vector& repeats) const { +NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { - auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output->rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - prepareSpecialUse({output}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); - prepareSpecialUse({output}, {this}); + prepareSpecialUse({&output}, {this}); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); + prepareSpecialUse({&output}, {this}); manager.synchronize(); diff --git a/libnd4j/blas/cuda/NDArrayLambda.hpp b/libnd4j/blas/cuda/NDArrayLambda.hpp index c27476bfb..15028dfaa 100644 --- 
a/libnd4j/blas/cuda/NDArrayLambda.hpp +++ b/libnd4j/blas/cuda/NDArrayLambda.hpp @@ -247,73 +247,73 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyLambda(Lambda func, NDArray* target) { - auto result = target == nullptr ? this : target; +void NDArray::applyLambda(Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType()) + if (dtype != target.dataType()) throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); - prepareSpecialUse({result}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this}); + //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target) { - auto result = target == nullptr ? 
this : target; +void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != other->dataType()) + if (dtype != target.dataType() || dtype != other.dataType()) throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); + //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - prepareSpecialUse({result}, {this, other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, other}); + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(Lambda func, NDArray* target) { - auto result = target == nullptr ? 
this : target; +void NDArray::applyIndexedLambda(Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType()) + if (dtype != target.dataType()) throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); - prepareSpecialUse({result}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this}); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target) { - auto result = target == nullptr ? 
this : target; +void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != other->dataType()) + if (dtype != target.dataType() || dtype != other.dataType()) throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); - prepareSpecialUse({result}, {this, other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, other}); + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target) { - auto result = target == nullptr ? 
this : target; +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) { + auto dtype = this->dataType(); - if (dtype != result->dataType() || dtype != second->dataType() || dtype != third->dataType()) + if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType()) throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); - prepareSpecialUse({result}, {this, second, third}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second->specialBuffer(), second->specialShapeInfo(), third->specialBuffer(), third->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({result}, {this, second, third}); + prepareSpecialUse({&target}, {this, &second, &third}); + BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &second, &third}); } diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 4e879d247..7561e96cc 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -91,6 +91,10 @@ namespace nd4j { template FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); + + template + // struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; + struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || 
std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; }; diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 75df72e70..cb1461226 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -44,7 +44,7 @@ namespace nd4j { } NDArray* NDArrayList::read(int idx) { - return readRaw(idx)->dup(); + return new NDArray(readRaw(idx)->dup()); } nd4j::DataType NDArrayList::dataType() { @@ -114,7 +114,7 @@ namespace nd4j { } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - + //_elements++; // storing reference @@ -136,11 +136,10 @@ namespace nd4j { std::vector args({axis}); auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto result = array->allTensorsAlongDimension(newAxis); - for (int e = 0; e < result->size(); e++) { - auto chunk = result->at(e);//->dup(array->ordering()); - write(e, chunk->dup(array->ordering())); + for (int e = 0; e < result.size(); e++) { + auto chunk = result.at(e);//->dup(array->ordering()); + write(e, new NDArray(chunk->dup(array->ordering()))); } - delete result; } NDArray* NDArrayList::stack() { @@ -161,7 +160,7 @@ namespace nd4j { auto result = op.execute(inputs, {}, {}, {}); - auto array = result->at(0)->dup(); + auto array = new NDArray(result->at(0)->dup()); delete result; @@ -214,13 +213,11 @@ namespace nd4j { auto tads = array->allTensorsAlongDimension(axis); int indicesSize = indices.size(); - if (tads->size() != indicesSize) + if (tads.size() != indicesSize) throw std::runtime_error("Number of TADs should match number of indices"); for (int e = 0; e < indicesSize; e++) - 
tads->at(e)->assign(_chunks[indices[e]]); - - delete tads; + tads.at(e)->assign(_chunks[indices[e]]); return array; } @@ -234,7 +231,7 @@ namespace nd4j { list->_elements.store(_elements.load()); for (auto const& v : _chunks) { - list->_chunks[v.first] = v.second->dup(); + list->_chunks[v.first] = new NDArray(v.second->dup()); } return list; diff --git a/libnd4j/include/graph/execution/impl/LogicConditional.cpp b/libnd4j/include/graph/execution/impl/LogicConditional.cpp index 62a533ee7..fb1f0fa1e 100644 --- a/libnd4j/include/graph/execution/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/execution/impl/LogicConditional.cpp @@ -48,7 +48,7 @@ namespace nd4j { } else { // FIXME: in some cases it's possible to have no NDArray if (inputVar->hasNDArray()) - innerVar->setNDArray(inputVar->getNDArray()->dup()); + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); } } diff --git a/libnd4j/include/graph/execution/impl/LogicWhile.cpp b/libnd4j/include/graph/execution/impl/LogicWhile.cpp index bdabdc6bc..147c35248 100644 --- a/libnd4j/include/graph/execution/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/execution/impl/LogicWhile.cpp @@ -56,7 +56,7 @@ namespace nd4j { } else { // FIXME: in some cases it's possible to have no NDArray if (inputVar->hasNDArray()) - innerVar->setNDArray(inputVar->getNDArray()->dup()); + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); } } diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index d77bded2e..5b8f00b25 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -40,7 +40,7 @@ namespace nd4j { result->setIndex(this->_index); if (this->_ndarray != nullptr) - result->setNDArray(this->_ndarray->template asT()); + result->setNDArray(new NDArray(this->_ndarray->template asT())); // FIXME: add support for ArrayList if (this->_list != nullptr) { @@ -61,7 +61,7 @@ namespace nd4j { result->_index = this->_index; if 
(this->_ndarray != nullptr) - result->_ndarray = this->_ndarray->dup(this->_ndarray->ordering()); + result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering())); if (this->_list != nullptr) result->_list = this->_list->clone(); diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 67ca25e07..d24c31b84 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -93,7 +93,7 @@ namespace nd4j { } OpBenchmark* clone() override { - return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : _x->dup() , _y == nullptr ? _y : _y->dup(), _z == nullptr ? _z : _z->dup()); + return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup())); } }; } diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 0f495bf96..189143f03 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -230,17 +230,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con bool cNcont = N == 1 || C->strideAt(1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); aMcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('f'); + pB = new NDArray(B->dup('f')); toDelete.push_back(pB); bKcont = true; } if(!cMcont && !cNcont) { - pC = C->dup('f'); + pC = new NDArray(C->dup('f')); toDelete.push_back(pC); cMcont = true; } @@ -332,7 +332,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* bool aNcont = N == 1 || A->strideAt(1) == 1; if(!aMcont && !aNcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); aMcont = true; } const CBLAS_ORDER blasOrder = aMcont ? 
CblasColMajor : CblasRowMajor; diff --git a/libnd4j/include/helpers/cpu/householder.cpp b/libnd4j/include/helpers/cpu/householder.cpp index 7fa82de8d..024695583 100644 --- a/libnd4j/include/helpers/cpu/householder.cpp +++ b/libnd4j/include/helpers/cpu/householder.cpp @@ -60,11 +60,10 @@ NDArray Householder::evalHHmatrix(const NDArray& x) { w.p(Nd4jLong(0), 1.f); wT.assign(&w); - auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); + NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); identity.setIdentity(); // identity matrix return identity - mmul(w, wT) * coeff; - } ////////////////////////////////////////////////////////////////////////// @@ -95,9 +94,9 @@ void Householder::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, coeff = -u0 / normX; if(x.isRowVector()) - tail.assign(x({0,0, 1,-1}) / u0); + tail.assign(static_cast(x({0,0, 1,-1})) / u0); else - tail.assign(x({1,-1, 0,0,}) / u0); + tail.assign(static_cast(x({1,-1, 0,0,})) / u0); } } diff --git a/libnd4j/include/helpers/cpu/jacobiSVD.cpp b/libnd4j/include/helpers/cpu/jacobiSVD.cpp index b8a51195e..4ba2bfe0a 100644 --- a/libnd4j/include/helpers/cpu/jacobiSVD.cpp +++ b/libnd4j/include/helpers/cpu/jacobiSVD.cpp @@ -269,7 +269,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { HHcolPivQR qr(matrix / scale); _m.assign(qr._qr({0,_cols, 0,_cols})); - _m.fillAsTriangular(0., 0, 0, 'l'); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); @@ -288,7 +288,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { auto matrixT = matrix.transpose(); HHcolPivQR qr(matrixT / scale); _m.assign(qr._qr({0,_rows, 0,_rows})); - _m.fillAsTriangular(0., 0, 0, 'l'); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); _m.transposei(); HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! 
@@ -305,7 +305,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { } else { - _m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); + _m.assign(static_cast(matrix({0,_diagSize, 0,_diagSize})) / scale); if(_calcU) _u.setIdentity(); @@ -366,7 +366,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { _s.p(i, math::nd4j_abs(_m.e(i,i))); if(_calcU && _m.e(i,i) < (T)0.) { auto temp = _u({0,0, i,i+1}, true); - temp.applyTransform(transform::Neg, &temp, nullptr); + temp.applyTransform(transform::Neg, temp, nullptr); } } diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 38d3b9ff4..4bf2be639 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -223,26 +223,26 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const T almostZero = DataTypeUtils::min(); T maxElem; if(len == 1) - maxElem = math::nd4j_abs(diagInterval->template e(0)); + maxElem = math::nd4j_abs(diagInterval.template e(0)); else - maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); + maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); T epsBig = (T)8. 
* DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - if(diagInterval->template e(0) < epsBig) - diagInterval->p(Nd4jLong(0), epsBig); + if(diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); for(int i=1; i < len; ++i) if(math::nd4j_abs(colVec0->template e(i)) < eps) colVec0->p(i, 0.f); for(int i=1; i < len; i++) - if(diagInterval->template e(i) < epsBig) { + if(diagInterval.template e(i) < epsBig) { deflation1(col1, shift, i, len); for(int i = 0; i < len; ++i) - diagInterval->p(i, _m.e(col1+shift+i,col1+shift+i)); + diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); } { @@ -261,7 +261,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh int p = 1; for(int i=1; i(diagInterval->template e(i)) < almostZero) + if(math::nd4j_abs(diagInterval.template e(i)) < almostZero) permut[p++] = i; int k = 1, m = ind+1; @@ -271,7 +271,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh permut[p] = m++; else if(m >= len) permut[p] = k++; - else if(diagInterval->template e(k) < diagInterval->template e(m)) + else if(diagInterval.template e(k) < diagInterval.template e(m)) permut[p] = m++; else permut[p] = k++; @@ -281,7 +281,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(totDefl) { for(int i=1; i(diagInterval->template e(ki)) < almostZero || diagInterval->template e(0) < diagInterval->template e(ki)) + if(math::nd4j_abs(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) permut[i-1] = permut[i]; else { permut[i-1] = 0; @@ -303,10 +303,10 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const int ki = permut[len - (totDefl ? 
i+1 : i)]; const int jac = tCol[ki]; - T _e0 = diagInterval->template e(jac); + T _e0 = diagInterval.template e(jac); //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval->p(jac, diagInterval->template e(i)); - diagInterval->p(i, _e0); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); if(i!=0 && jac!=0) { _e0 = colVec0->template e(jac); @@ -315,9 +315,8 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh colVec0->p(i, _e0); } - NDArray* temp1 = nullptr, *temp2 = nullptr; if (_calcU) { - auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); + auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); auto temp3 = temp1; temp1.assign(temp2); @@ -352,12 +351,12 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh { int i = len-1; - while(i > 0 && (math::nd4j_abs(diagInterval->template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) + while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) --i; for(; i > 1; --i) { - if( (diagInterval->template e(i) - diagInterval->template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval->template e(i) - diagInterval->template e(i-1)) >= epsBig) + if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { + if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); } @@ -365,7 +364,6 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh } delete colVec0; - delete diagInterval; } @@ -609,9 +607,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA const T 
almostZero = DataTypeUtils::min(); auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); - auto diag = *diagP; - delete diagP; + auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); diag.p(Nd4jLong(0), T(0)); singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); @@ -730,8 +726,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); temp.assign(0.); auto diag = _m.diagonal('c'); - (*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - delete diag; + diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); return; } @@ -762,11 +757,6 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif f.assign(_u({0,1, col1+k+1, col1+n}, true)); } - // UofSVD.printIndexedBuffer(); - // VofSVD.printIndexedBuffer(); - // singVals.printIndexedBuffer(); - // printf("!! 
\n"); - if (_calcV) _v.p(row1W+k, col1W, 1.f); @@ -789,14 +779,10 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif temp.assign(_u({col1, col1+k+1, i, i+1}, true)); } - auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); - temp1.assign(q1 * c0); - auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); - temp2.assign(q1 * (-s0)); - auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true); - temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0); - auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true); - temp4 *= c0; + _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); + _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); + _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); + _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; } else { @@ -844,8 +830,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); blockM = 0.f; auto diag = blockM.diagonal('c'); - diag->assign(singVals); - delete diag; + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index d191c7803..40bb9453f 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -285,17 +285,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou bool cNcont = N == 1 || C->strideAt(1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); aMcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('f'); + pB = new NDArray(B->dup('f')); toDelete.push_back(pB); bKcont = true; } if(!cMcont) { - pC = C->dup('f'); + pC = new NDArray(C->dup('f')); toDelete.push_back(pC); cMcont = true; } @@ -418,7 +418,7 @@ 
NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* bool aNcont = N == 1 || A->strideAt(1) == 1; if(!aMcont && !aNcont) { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); aMcont = true; } @@ -866,12 +866,12 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, bool cNcont = N == 1 || C->strideAt(-1) == 1; if(!aMcont && !aKcont) { - pA = A->dup('c'); + pA = new NDArray(A->dup('c')); toDelete.push_back(pA); aKcont = true; } if(!bKcont && !bNcont) { - pB = B->dup('c'); + pB = new NDArray(B->dup('c')); toDelete.push_back(pB); bNcont = true; } diff --git a/libnd4j/include/loops/impl/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp index 5a4a9db41..b12ff5796 100644 --- a/libnd4j/include/loops/impl/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -82,7 +82,7 @@ namespace nd4j { // now we actually apply quantization auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - rz[e] = static_cast(nd4j::math::nd4j_round(1.0f * x[e] / nd4j::math::nd4j_max(amax, amin) * max_byte)); + rz[e] = static_cast(nd4j::math::nd4j_round( 1.0f * static_cast(x[e]) / nd4j::math::nd4j_max(amax, amin) * max_byte)); } }; @@ -180,7 +180,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write) for (auto e = start; e < stop; e += increment) { int el = x[e]; int ael = nd4j::math::nd4j_abs(el) - 1; - z[ael] += el > 0 ? threshold : -threshold; + z[ael] += el > 0 ? 
static_cast(threshold) : static_cast(-threshold); } }; diff --git a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp index 42b171226..8ce3cbf75 100644 --- a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp @@ -32,21 +32,19 @@ namespace nd4j { REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); auto tmp = x->dup(); - tmp->applyTransform(nd4j::transform::Neg, nullptr, nullptr); + tmp.applyTransform(nd4j::transform::Neg, tmp); auto z = OUTPUT_VARIABLE(0); - helpers::concat(block.launchContext(), {x, tmp}, *z, x->rankOf()-1); + helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1); // NDArrayFactory::concat({x, tmp}, -1, z); // TODO: make this configurable? double threshold = 0.0; - z->applyScalar(nd4j::scalar::RELU, threshold); + z->applyScalar(nd4j::scalar::RELU, threshold, *z); STORE_RESULT(z); - delete tmp; - return Status::OK(); } @@ -61,7 +59,7 @@ namespace nd4j { std::vector shape; for (int e = 0; e < shape::rank(inShape); e++) shape.emplace_back(shape::shapeOf(inShape)[e]); - + shape[shape.size()-1] *= 2; auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); @@ -94,7 +92,7 @@ namespace nd4j { auto pos = dec->at(0); auto neg = dec->at(1); - pos->applyPairwiseTransform(nd4j::pairwise::Subtract, neg, epsilon, nullptr); + pos->applyPairwiseTransform(nd4j::pairwise::Subtract, *neg, *epsilon); delete tmpResult; delete dec; diff --git a/libnd4j/include/ops/declarable/generic/activations/cube.cpp b/libnd4j/include/ops/declarable/generic/activations/cube.cpp index 075da4b00..75a33ab79 100644 --- a/libnd4j/include/ops/declarable/generic/activations/cube.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/cube.cpp @@ -31,9 +31,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = 
OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::Cube, output, nullptr); + input->applyTransform(nd4j::transform::Cube, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/activations/elu.cpp index 03670ddab..85becd858 100644 --- a/libnd4j/include/ops/declarable/generic/activations/elu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/elu.cpp @@ -32,7 +32,7 @@ namespace nd4j { const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - input->applyScalar(nd4j::scalar::ELU, alpha, output); + input->applyScalar(nd4j::scalar::ELU, alpha, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp b/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp index 40a98575a..d8b937a0a 100644 --- a/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::HardSigmoid, output, nullptr); + input->applyTransform(nd4j::transform::HardSigmoid, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp b/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp index 287dcc113..a4d9fe4e6 100644 --- a/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::HardTanh, output, nullptr); + input->applyTransform(nd4j::transform::HardTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git 
a/libnd4j/include/ops/declarable/generic/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/activations/identity.cpp index f65448d92..5ae5b0690 100644 --- a/libnd4j/include/ops/declarable/generic/activations/identity.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/identity.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto z = this->getZ(block); // just for lulz - first->applyTransform(nd4j::transform::Identity, z, nullptr); + first->applyTransform(nd4j::transform::Identity, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp index 0bb47e4b4..b96ab9a3f 100644 --- a/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp @@ -33,7 +33,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(i); auto z = OUTPUT_VARIABLE(i); - x->applyTransform(transform::Identity, z, nullptr); + x->applyTransform(transform::Identity, *z); } } diff --git a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp index ef65c4822..80404135f 100644 --- a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp @@ -31,7 +31,7 @@ namespace nd4j { float alpha = block.numT() > 0 ? 
T_ARG(0) : 0.01f; - input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output); + input->applyScalar(nd4j::scalar::LeakyRELU, alpha, *output); STORE_RESULT(output); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp b/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp index 7e85ab9bf..5bae4d2dc 100644 --- a/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::RationalTanh, output, nullptr); + input->applyTransform(nd4j::transform::RationalTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp b/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp index 69d5faa2a..40738c343 100644 --- a/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp @@ -30,9 +30,9 @@ namespace nd4j { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::RectifiedTanh, output, nullptr); + input->applyTransform(nd4j::transform::RectifiedTanh, *output); STORE_RESULT(output); - + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/activations/relu.cpp index 3b556ef1f..2c8b978ff 100644 --- a/libnd4j/include/ops/declarable/generic/activations/relu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/relu.cpp @@ -32,7 +32,7 @@ namespace nd4j { auto scalar = block.numT() > 0 ? 
block.getTArguments()->at(0) : 0.0; - first->applyScalar(nd4j::scalar::RELU, scalar, z); + first->applyScalar(nd4j::scalar::RELU, scalar, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/relu6.cpp b/libnd4j/include/ops/declarable/generic/activations/relu6.cpp index a6861b3f7..cf12d1592 100644 --- a/libnd4j/include/ops/declarable/generic/activations/relu6.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/relu6.cpp @@ -33,8 +33,8 @@ CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), output); - + input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), *output); + return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/selu.cpp b/libnd4j/include/ops/declarable/generic/activations/selu.cpp index 20ac42db2..ca16f6832 100644 --- a/libnd4j/include/ops/declarable/generic/activations/selu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/selu.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SELU, z, nullptr); + first->applyTransform(nd4j::transform::SELU, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp b/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp index d6f341298..fb8e507a7 100644 --- a/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::Sigmoid, z, nullptr); + first->applyTransform(nd4j::transform::Sigmoid, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/softplus.cpp b/libnd4j/include/ops/declarable/generic/activations/softplus.cpp index 7b3ba74f2..bd538ab71 100644 
--- a/libnd4j/include/ops/declarable/generic/activations/softplus.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/softplus.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SoftPlus, z, nullptr); + first->applyTransform(nd4j::transform::SoftPlus, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/softsign.cpp b/libnd4j/include/ops/declarable/generic/activations/softsign.cpp index 50ce3a817..99e52ab68 100644 --- a/libnd4j/include/ops/declarable/generic/activations/softsign.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/softsign.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::SoftSign, z, nullptr); + first->applyTransform(nd4j::transform::SoftSign, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/activations/tanh.cpp b/libnd4j/include/ops/declarable/generic/activations/tanh.cpp index b27d07806..5677da728 100644 --- a/libnd4j/include/ops/declarable/generic/activations/tanh.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/tanh.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(nd4j::transform::Tanh, z, nullptr); + first->applyTransform(nd4j::transform::Tanh, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp index 52d01429f..6eb3728ed 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); + 
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp index b8469d83a..4683e3f3e 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp index f7f3f479a..1d79a84f3 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index 89d380d02..7a2c61c95 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - 
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index f18314910..0a1c3d5c8 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 36b0defd0..0543cc72d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index ab4ed9880..4f0fec82d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -34,7 +34,7 @@ namespace nd4j 
{ BROADCAST_CHECK_EMPTY(x,y,z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp index 1b949eb35..65d20589f 100644 --- a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp @@ -37,14 +37,14 @@ namespace nd4j { if (block.width() > 2) { auto alpha = INPUT_VARIABLE(2); - REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); + REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); } else if (block.getTArguments()->size() > 0) { a = T_ARG(0); } ExtraArguments arguments({a}); - y->applyPairwiseTransform(pairwise::Axpy, x, z, &arguments); + y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/blas/svd.cpp b/libnd4j/include/ops/declarable/generic/blas/svd.cpp index d62c621dd..8db2c2ff3 100644 --- a/libnd4j/include/ops/declarable/generic/blas/svd.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/svd.cpp @@ -33,8 +33,12 @@ CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) { const int rank = x->rankOf(); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - const bool fullUV = (bool)INT_ARG(0); + bool fullUV = (bool)INT_ARG(0); const bool calcUV = (bool)INT_ARG(1); + + if(calcUV == false) + fullUV = false; + const int switchNum = INT_ARG(2); // #ifndef __CUDABLAS__ diff --git a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp index 651e21aab..83cbc9004 100644 --- 
a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Not, z, nullptr); + x->applyTransform(transform::Not, *z); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/boolean/select.cpp b/libnd4j/include/ops/declarable/generic/boolean/select.cpp index 56b2c3238..92cb5e421 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/select.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/select.cpp @@ -70,17 +70,13 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!cond->e(e)) { - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); } else { - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } } - - delete tadsX; - delete tadsY; - delete tadsZ; } } diff --git a/libnd4j/include/ops/declarable/generic/boolean/where.cpp b/libnd4j/include/ops/declarable/generic/boolean/where.cpp index b5800d3d6..6aa646cb6 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where.cpp @@ -59,17 +59,13 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) { - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); } else { - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } } - - delete tadsX; - delete tadsY; - delete tadsZ; } } else { // in this case we return 2D matrix, which basically contains coordinates fo true diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp 
b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index 19a9a0ce9..c06ef07d1 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -89,16 +89,12 @@ namespace nd4j { auto tadsY = y->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims); - for (int e = 0; e < tadsX->size(); e++) { + for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) - tadsZ->at(e)->assign(tadsY->at(e)); + tadsZ.at(e)->assign(tadsY.at(e)); else - tadsZ->at(e)->assign(tadsX->at(e)); + tadsZ.at(e)->assign(tadsX.at(e)); } - - delete tadsX; - delete tadsY; - delete tadsZ; } } else { // in this case we return 2D matrix, which basically contains coordinates fo true diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp index 7d7e6f965..415a2c37a 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - + BROADCAST_CHECK_EMPTY(x,y,z); auto tZ = BroadcastHelper::broadcastApply(nd4j::BroadcastOpsTuple::Add(), x, y, z); @@ -82,14 +82,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(epsNext); if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(epsNext); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp index f11e18be6..24b673a8c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp @@ 
-39,7 +39,7 @@ namespace nd4j { else if (tZ != z) { OVERWRITE_RESULT(tZ); } - + return ND4J_STATUS_OK; } DECLARE_SYN(set, assign); @@ -80,7 +80,6 @@ namespace nd4j { if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(epsNext); } @@ -98,7 +97,7 @@ namespace nd4j { Nd4jLong *shapeE; Nd4jLong *shapeG; - + COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp index 8b894ac6d..32a7d7d65 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp @@ -28,7 +28,7 @@ namespace nd4j { namespace ops { BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { - + auto y = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); @@ -36,8 +36,8 @@ BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { BROADCAST_CHECK_EMPTY(x,y,z); // auto tZ = BroadcastHelper::template broadcastApply>(y, x, z); - x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), y, z, true); - + x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true); + // if (tZ == nullptr) // return ND4J_STATUS_KERNEL_FAILURE; // else if (tZ != z) { diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp index 84d739ee2..1811781f1 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp @@ -81,7 +81,7 @@ namespace nd4j { // Y gradient //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, nullptr, nullptr); + gradY->applyTransform(transform::Neg, 
*gradY); } else if (y->isScalar()) { // scalar case @@ -91,17 +91,17 @@ namespace nd4j { //tmpX.printBuffer("SumX"); //tmp.printBuffer("Sum Eps"); gradY->assign(tmp * tmpX / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, nullptr, nullptr); + gradY->applyTransform(transform::Neg, *gradY); - //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); + //epsNext->applyLambda(lambdaS, *gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); } else { // broadcast case auto preX = *epsNext / *y; NDArray negX(*x); - x->applyTransform(transform::Neg, &negX); + x->applyTransform(transform::Neg, negX); auto preY = *epsNext * negX / ((*y) * (*y)); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -110,14 +110,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index ea60c2f21..d442d89e7 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -69,7 +69,7 @@ namespace nd4j { std::unique_ptr tmpResult(op.execute({x, y}, {}, {}, {})); if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, tmpResult->at(0), gradY, nullptr); + epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY); else // epsNext is greater than gradY { std::vector dims(epsNext->rankOf() * 2); @@ -78,7 +78,7 @@ namespace nd4j { dims[d * 2 + 1] = 1; } auto tempIn((*tmpResult->at(0))(dims)); - 
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, &tempIn, gradY, nullptr); + (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp index d2a9f6260..d50ffacaa 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp @@ -79,42 +79,42 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { const Nd4jLong yLen = y->lengthOf(); if(x->isScalar() && y->isScalar()) { // both are scalars - y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); - x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); //dLdx->assign((*y) * (*dLdz)); //dLdy->assign((*x) * (*dLdz)); } else if(x->isScalar()) { // x is scalar and y is not dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, x, dLdy, nullptr); + dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); } else if(y->isScalar()) { // y is scalar and x is not dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, y, dLdx); - } + dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); + } else if(x->isSameShape(y)) { - x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); - y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); } else if (x->isSameShape(dLdz)) { - + auto yTiled = NDArray(dLdz, false, block.launchContext()); y->tile(yTiled); std::vector axesForY = 
ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); - - dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); - yTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); - } + + dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); + yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (y->isSameShape(dLdz)) { auto xTiled = NDArray(dLdz, false, block.launchContext()); x->tile(xTiled); std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); - xTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); + + dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); + xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); } else { @@ -124,16 +124,16 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { y->tile(yTiled); std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); - dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); + + dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); + dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); } return Status::OK(); } DECLARE_SHAPE_FN(multiply_bp) { - + auto xShapeInfo = inputShape->at(0); auto yShapeInfo = inputShape->at(1); @@ -181,8 +181,8 @@ DECLARE_SHAPE_FN(multiply_bp) { T tmpX = x->template reduceNumber>(); gradY->assign(tmpX); - - epsNext->applyLambda(lambdaS, gradX); + + epsNext->applyLambda(lambdaS, *gradX); } else { // broadcast case @@ -201,7 +201,7 @@ DECLARE_SHAPE_FN(multiply_bp) { auto sum = preX->template reduceAlongDimension>(axisX); gradX->assign(sum); delete sum; - } else + } 
else gradX->assign(preX); if (axisY.size() > 0) { diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index 7b4e374d5..3e7445cf0 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -71,7 +71,7 @@ namespace nd4j { // X gradient //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - epsNext->applyPairwiseTransform(pairwise::Divide, y, gradX, nullptr); + epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); // Y gradient //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); @@ -84,16 +84,16 @@ namespace nd4j { auto tmp = epsNext->reduceNumber(reduce::Sum); auto tmpX = x->reduceNumber(reduce::Sum); gradY->assign(tmp * -tmpX / ((*y) * (*y))); - + //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); } else { // broadcast case auto preX = *epsNext / *y; NDArray negX(*x); - x->applyTransform(transform::Neg, &negX); + x->applyTransform(transform::Neg, negX); auto preY = *epsNext * negX / ((*y) * (*y)); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -102,14 +102,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp index 6abe8ff9c..04c4c926e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp @@ 
-34,7 +34,7 @@ namespace nd4j { BROADCAST_CHECK_EMPTY(x,y,z); REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); - x->applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); return Status::OK(); } @@ -67,7 +67,7 @@ namespace nd4j { // X gradient //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, nullptr, nullptr); + gradX->applyTransform(transform::Neg, *gradX); // Y gradient //epsNext->applyPairwiseLambda(x, lambdaY, gradY); gradY->assign((*epsNext) / (*x)); @@ -78,14 +78,14 @@ namespace nd4j { gradY->assign(tmp / tmpX); gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, nullptr, nullptr); + gradX->applyTransform(transform::Neg, *gradX); } else { // broadcast case auto preY = (*epsNext) / (*x); auto preX = *epsNext * (*y) / ((*x) * (*x)); - preX.applyTransform(transform::Neg, nullptr, nullptr); + preX.applyTransform(transform::Neg, preX); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); @@ -93,14 +93,12 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index af282fe7c..dbb14c78b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -61,13 +61,13 @@ namespace nd4j { if 
(x->isSameShape(y)) { // PWT case case - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); gradY->assign(epsNext); } else if (y->isScalar()) { // scalar case auto tmp = epsNext->reduceNumber(reduce::Sum); gradY->assign(tmp); - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); } else { // broadcastable auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); @@ -75,20 +75,18 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - sum->applyTransform(transform::Neg, gradX); - delete sum; + sum.applyTransform(transform::Neg, *gradX); } else { - epsNext->applyTransform(transform::Neg, gradX, nullptr); + epsNext->applyTransform(transform::Neg, *gradX); } if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else { gradY->assign(epsNext); } - } + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp index 280a09857..ae9c93d4d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp @@ -87,7 +87,7 @@ namespace nd4j { // scalar case auto tmpX = x->reduceNumber(reduce::Sum); gradY->assign(tmpX); - + //epsNext->applyPairwiseLambda(x, lambdaS, gradX); gradX->assign((*epsNext) * ts * ((*x) - (*y))); } else { @@ -98,37 +98,31 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); auto resX = (*epsNext) * ts * 
((*x) - (*y)); - preX->assign(resX); + preX.assign(resX); auto resY = (*epsNext) * ts * ((*y) - (*x)); - preY->assign(resY); + preY.assign(resY); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index 76f2d6830..40bbb8559 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -62,7 +62,7 @@ namespace nd4j { if (x->isSameShape(y)) { // PWT case case - epsNext->applyTransform(transform::Neg, gradY, nullptr); + epsNext->applyTransform(transform::Neg, *gradY); gradX->assign(epsNext); } else if (y->isScalar()) { // scalar case @@ -77,18 +77,16 @@ namespace nd4j { if (axisX.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(epsNext); if (axisY.size() > 0) { auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - sum->applyTransform(transform::Neg, gradY); - delete sum; + sum.applyTransform(transform::Neg, *gradY); } else { - epsNext->applyTransform(transform::Neg, gradY); + epsNext->applyTransform(transform::Neg, *gradY); } - } + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp 
b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp index 472cb060d..5296e8844 100644 --- a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp @@ -26,7 +26,7 @@ namespace nd4j { namespace ops { /** * This operation is, basically IF statement - * + * * arg_0 is our "signal" * arg_1 is condition that will determine transition */ @@ -41,10 +41,10 @@ namespace nd4j { // but we'll ensure only one node is active, and other is disabled if (condition->e(0) == 0) { block.setBranch(0); - this->storeResult(block, 0, input->dup()); + this->storeResult(block, 0, new NDArray(input->dup())); } else { block.setBranch(1); - this->storeResult(block, 1, *input->dup()); + this->storeResult(block, 1, new NDArray(input->dup())); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h index 5e91641ca..e497be416 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h @@ -30,7 +30,7 @@ namespace nd4j { namespace ops { class BroadcastHelper { - public: + public: static FORCEINLINE NDArray* broadcastApply(nd4j::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { if(x->isEmpty() || y->isEmpty()) { @@ -42,34 +42,34 @@ namespace nd4j { std::unique_ptr ptr; if (!Environment::getInstance()->isExperimentalBuild()) { if (y->dataType() != x->dataType()) { - y = y->cast(x->dataType()); + y = new NDArray(y->cast(x->dataType())); std::unique_ptr ptr2(y); ptr.swap(ptr2); } } if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, y, z, nullptr); + x->applyPairwiseTransform(op.p, *y, *z); } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(y), z); + x->applyScalarArr(op.s, const_cast(*y), 
*z); } else if (x->isScalar() && !y->isScalar()) { if (z->isSameShape(y)) { if (op.s == scalar::Add || op.s == scalar::Multiply ) { - y->applyScalarArr(op.s, x, z, nullptr); + y->applyScalarArr(op.s, *x, *z); } else if (op.s == scalar::SquaredSubtract) { - y->applyScalarArr(scalar::SquaredReverseSubtract, x, z, nullptr); + y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); } else if (op.s == scalar::Subtract) { - y->applyScalarArr(scalar::ReverseSubtract, x, z, nullptr); + y->applyScalarArr(scalar::ReverseSubtract, *x, *z); } else if (op.s == scalar::Divide) { - y->applyScalarArr(scalar::ReverseDivide, x, z, nullptr); + y->applyScalarArr(scalar::ReverseDivide, *x, *z); } else if (op.s == scalar::Pow) { - y->applyScalarArr(scalar::ReversePow, x, z, nullptr); + y->applyScalarArr(scalar::ReversePow, *x, *z); } else if (op.s == scalar::ReverseSubtract) { - y->applyScalarArr(scalar::Subtract, x, z, nullptr); + y->applyScalarArr(scalar::Subtract, *x, *z); } else if (op.s == scalar::ReverseDivide) { - y->applyScalarArr(scalar::Divide, x, z, nullptr); + y->applyScalarArr(scalar::Divide, *x, *z); } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { - y->applyScalarArr(op.s, x, z, nullptr); + y->applyScalarArr(op.s, *x, *z); } else if (op.s == scalar::CopyPws) { z->assign(y); } else { @@ -84,9 +84,9 @@ namespace nd4j { return tZ; } } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(y), z, nullptr); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else { auto sx = ShapeUtils::shapeAsString(x); @@ -107,16 +107,16 @@ namespace nd4j { } if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, y, z, nullptr); + 
x->applyPairwiseTransform(op.p, *y, *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(y), z); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (x->isScalar() && !y->isScalar()) { if (z->isSameShape(y)) { //z->assign(x); - x->applyPairwiseTransform(op.p, y, z, extraArgs); + x->applyPairwiseTransform(op.p, *y, *z, extraArgs); return z; } else { auto v = y->getShapeAsVector(); @@ -125,9 +125,9 @@ namespace nd4j { return tZ; } } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(y), z, nullptr); + x->applyScalarArr(op.s, const_cast(*y), *z); } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, y, z, true, extraArgs); + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); return z; } else { auto sx = ShapeUtils::shapeAsString(x); diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index c5bc3ca6c..2d854ae0b 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -51,12 +51,12 @@ namespace nd4j { std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); auto tads = array->allTensorsAlongDimension( axis); - for (int e = 0; e < tads->size(); e++) { + for (int e = 0; e < tads.size(); e++) { auto idx = indices->e(e); - if (idx >= tads->size()) + if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - auto arr = tads->at(e)->dup(array->ordering()); + auto arr = new NDArray(tads.at(e)->dup(array->ordering())); auto res = list->write(idx, arr); if (res != ND4J_STATUS_OK) return res; @@ -65,7 +65,6 @@ namespace nd4j { if (!hasList) //OVERWRITE_RESULT(list); setupResultList(list, 
block); - delete tads; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index f2399c9d3..5a403dd06 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -55,7 +55,7 @@ namespace nd4j { std::vector indices(2 * array->rankOf(), 0); for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { int c_size = sizes->e(e); - + REQUIRE_TRUE(c_size > 0, 0, "Slice size should have postive value, but got %i instead", c_size); REQUIRE_TRUE(cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, "Slices size should NOT be higher then number of TADs of source array. Source size: [%i]; Slice start: [%i]; Slice size: [%i]", array->sizeAt(0), cnt, c_size); @@ -63,11 +63,11 @@ namespace nd4j { indices[0] = cnt; indices[1] = cnt + c_size; cnt += c_size; - + auto subarray = (*array)(indices); - auto status = list->write(e, subarray.dup(array->ordering())); - + auto status = list->write(e, new NDArray(subarray.dup(array->ordering()))); + if (status != ND4J_STATUS_OK) return status; } diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index 8ac1935b3..c9b32234e 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -39,7 +39,7 @@ namespace nd4j { //nd4j_printf("Writing [%i]:\n", idx->e(0)); //input->printShapeInfo("input shape"); //input->printIndexedBuffer("input buffer"); - Nd4jStatus result = list->write(idx->e(0), input->dup()); + Nd4jStatus result = list->write(idx->e(0), new NDArray(input->dup())); auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 2 output shape"); @@ -52,7 +52,7 @@ namespace nd4j { auto input = INPUT_VARIABLE(1); auto idx = INT_ARG(0); - Nd4jStatus result 
= list->write(idx, input->dup()); + Nd4jStatus result = list->write(idx, new NDArray(input->dup())); auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 1 output shape"); diff --git a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp index d028e5af8..ba488df65 100644 --- a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp @@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { NDArray E = *predictions - *labels; // dE_i/dp_i = sign(p_i - y_i) - E.applyTransform(nd4j::transform::Sign, dLdp); // dE/dp + E.applyTransform(nd4j::transform::Sign, *dLdp); // dE/dp // dE_i/dy_i = -sign(p_i - y_i) - E.applyTransform(nd4j::transform::Abs); + E.applyTransform(nd4j::transform::Abs, E); switch (reductionMode) { @@ -184,7 +184,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -210,7 +210,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, 
axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -238,7 +238,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp index 0b4e3fe89..7fe75c03a 100644 --- a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); } - NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); + NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. 
output->assign(&E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array output->assign(E.reduceNumber(reduce::Sum)); break; @@ -79,12 +79,12 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { NDArray sum; if (weights->isScalar()) sum = *weights * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -99,9 +99,9 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { if (numOfNonZeroWeights == 0) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - + break; } } @@ -111,7 +111,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } @@ -124,7 +124,7 @@ DECLARE_TYPES(cosine_distance_loss) { ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss) { - // labels and predictions must have the same shapes + // labels and predictions must have the same shapes auto predictionsShapeInfo = inputShape->at(0); auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); @@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { // input dimension can't be larger than labels/predictions/weights rank REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); + NDArray E = 1. 
- (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { else { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -249,7 +249,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -284,7 +284,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeights; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index b62dffad8..8670bf9e1 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -52,7 +52,7 @@ 
namespace nd4j { // We first need to convert binary labels to -1/1 labels (as floats) NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); - E.applyScalar(scalar::RELU, 0.0f, &E); + E.applyScalar(scalar::RELU, 0.0f, E); // multiply E on weights E *= *weightsBroad; @@ -172,11 +172,11 @@ namespace nd4j { NDArray z = (*labels * 2.f - 1.f); NDArray E = 1.f - z * (*logits); - E.applyScalar(scalar::RELU, 0.0f, &E); + E.applyScalar(scalar::RELU, 0.0f, E); // turn E into gradient mask NDArray gradientMask(E.getShapeInfo(), block.getWorkspace()); - E.applyTransform(nd4j::transform::Sign, &gradientMask); + E.applyTransform(nd4j::transform::Sign, gradientMask); dLdp->assign(-z * gradientMask); dLdl->assign(-2.f * (*logits) * gradientMask); @@ -192,7 +192,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -220,7 +220,7 @@ namespace nd4j { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -249,7 +249,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = 
ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index 3e7686a3d..e844b4126 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -46,17 +46,17 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - + // perform weights broadcasting/tile to predictions if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); auto error = *predictions - *labels; - error.applyTransform(transform::Abs); + error.applyTransform(transform::Abs, error); NDArray quadratic(error.getShapeInfo(), block.getWorkspace()); - error.applyScalar(scalar::MinPairwise, delta, &quadratic); - + error.applyScalar(scalar::MinPairwise, delta, quadratic); + NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; // multiply E on weights @@ -75,12 +75,12 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = *weights 
* E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) *output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -104,7 +104,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } @@ -173,24 +173,24 @@ DECLARE_SHAPE_FN(huber_loss) { NDArray diff = *predictions - *labels; NDArray absDiff(diff); - absDiff.applyTransform(transform::Abs); + absDiff.applyTransform(transform::Abs, absDiff); NDArray quadratic(absDiff); - absDiff.applyScalar(scalar::MinPairwise, delta, &quadratic); + absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta; NDArray lteMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::LessThanOrEqual, delta, <eMask); + absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); NDArray gtMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::GreaterThan, delta, >Mask); + absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); NDArray signDiff(diff); - diff.applyTransform(transform::Sign, &signDiff); + diff.applyTransform(transform::Sign, signDiff); - auto gtMaskFloat = *gtMask.cast(diff.dataType()); - auto lteMaskFloat = *lteMask.cast(diff.dataType()); + auto gtMaskFloat = gtMask.cast(diff.dataType()); + auto lteMaskFloat = lteMask.cast(diff.dataType()); dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff); @@ -207,7 +207,7 @@ DECLARE_SHAPE_FN(huber_loss) { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else 
dLdw->assign(E); @@ -235,7 +235,7 @@ DECLARE_SHAPE_FN(huber_loss) { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -264,7 +264,7 @@ DECLARE_SHAPE_FN(huber_loss) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index 33d5c03ec..f83947c69 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -29,11 +29,11 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto output = OUTPUT_VARIABLE(0); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" @@ -48,7 +48,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, 
*labels), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - + // perform weights broadcasting/tile to predictions if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) @@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { // multiply E on weights E *= *weightsBroad; - + switch (reductionMode) { case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(E); @@ -72,12 +72,12 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = *weights * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
*output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -101,13 +101,13 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -118,11 +118,11 @@ DECLARE_SHAPE_FN(log_loss) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS OP: shapes of weights and labels arrays should 
be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); @@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(log_loss) { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else // in this case output has the same shape as labels and predictions outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - + return SHAPELIST(outShapeInfo); } @@ -143,33 +143,33 @@ DECLARE_SHAPE_FN(log_loss) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" + int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients if(reductionMode == 0) reductionMode = 1; - - // FIXME: double? + + // FIXME: double? 
double epsilon = T_ARG(0); // input validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed + + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); @@ -179,24 +179,24 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray onePlusEpsMinusPredict = (1. 
+ epsilon) - *predictions; // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) - dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp + dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) - ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, dLdl); // dE/dy + ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); - + // process 3 possible reduction modes below - switch (reductionMode) { + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; *dLdl *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -208,9 +208,9 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
{ *dLdp = 0.; *dLdl = 0.; @@ -221,27 +221,27 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { NDArray temp = *weightsBroad / sum; *dLdp *= temp; *dLdl *= temp; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else + dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); } break; } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; if(weights->isScalar()) { if(weights->e(0) != 0.) 
numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { *dLdp = 0.; @@ -254,12 +254,12 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeightsScalar); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; *dLdp *= temp; *dLdl *= temp; @@ -270,13 +270,13 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -287,19 +287,19 @@ DECLARE_SHAPE_FN(log_loss_grad) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly 
!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 6f3d0c5dd..0d85c6e23 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -55,9 +55,9 
@@ namespace ops { NDArray E(labels->getShapeInfo(), block.getWorkspace()); if (computeFullLoss) - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); else - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); // multiply E on weights @@ -176,19 +176,19 @@ namespace ops { NDArray E(labels->getShapeInfo(), block.getWorkspace()); if (computeFullLoss) { - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); NDArray rDiv(labels->getShapeInfo(), block.getWorkspace()); - labels->applyScalar(scalar::ReverseDivide, 0.5f, &rDiv); + labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); } else { - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); dLdl->assign(-(*log_predictions)); } dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); - + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array @@ -200,7 +200,7 @@ namespace ops { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -228,7 +228,7 @@ namespace ops { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = 
ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -257,7 +257,7 @@ namespace ops { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index 003ae815b..ef511921f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -112,10 +112,10 @@ namespace nd4j { auto n = double(labels->sizeAt(1)); auto diffs = *predictions - *labels; - auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2); + auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); @@ -240,15 +240,15 @@ namespace nd4j { auto diffs = *predictions - *labels; std::vector 
reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2); + auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - auto sumPred = predictions->reduceAlongDims(reduce::Sum, reductionIdx, true); - auto sumLabel = labels->reduceAlongDims(reduce::Sum, reductionIdx, true); + auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); + auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1)))); @@ -273,7 +273,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -299,7 +299,7 @@ namespace nd4j { *dLdw = 0.; else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); 
@@ -327,7 +327,7 @@ namespace nd4j { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index c519ab020..f446d0bf0 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -35,8 +35,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { auto output = OUTPUT_VARIABLE(0); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // inputs validation + + // inputs validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); @@ -45,13 +45,13 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", 
reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) + if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); NDArray E(labels->getShapeInfo(), false, block.launchContext()); - predictions->applyPairwiseTransform(pairwise::SquaredSubtract, labels, &E, nullptr); + predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); // multiply E on weights E *= (*weightsBroad); @@ -60,7 +60,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(&E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array E.reduceNumber(reduce::Sum, *output); break; @@ -69,12 +69,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
(*output) = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -101,12 +101,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -121,7 +121,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); @@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { else // in this case output has the same shape as labels and predictions outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), 
shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - return SHAPELIST(outShapeInfo); + return SHAPELIST(outShapeInfo); } @@ -144,11 +144,11 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { - + auto predictions = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels @@ -157,8 +157,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients if(reductionMode == 0) reductionMode = 1; - - // inputs validation + + // inputs validation REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); @@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; - 
if(!weights->isScalar() && !weights->isSameShape(predictions)) + if(!weights->isScalar() && !weights->isSameShape(predictions)) weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); NDArray diff = *predictions - *labels; @@ -178,20 +178,20 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { dLdp->assign(2. * diff); // dE/dp // dE_i/dy_i = -2 * (p_i - y_i) // dLdl->assign(-(*dLdp)); // dE/dl - + NDArray E = diff * diff; - switch (reductionMode) { - + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -202,40 +202,40 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
{ - *dLdp = 0.; + *dLdp = 0.; *dLdw = 0.; } else { - + *dLdp *= *weightsBroad / sum; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else + dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); } break; } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights Nd4jLong numOfNonZeroWeights = 0; if(weights->isScalar()) { if(weights->e(0) != 0.) 
numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { - *dLdp = 0.; + *dLdp = 0.; *dLdw = 0.; } else { @@ -245,14 +245,14 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeights); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; + *dLdp *= temp; } break; } @@ -262,12 +262,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { if(weightsBroad != weights) delete weightsBroad; - + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -281,15 +281,15 @@ DECLARE_SHAPE_FN(mean_sqerr_loss_grad) { REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", 
shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index b3f707b23..5b0075466 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -38,27 +38,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" auto labelsSmoothing = T_ARG(0); - // input validation + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", 
ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(logits)) weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo())); - + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: auto newLabels = labels; if(labelsSmoothing != 0.) 
{ newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, newLabels, nullptr); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels); } - + NDArray E(labels, false, block.launchContext()); // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits @@ -66,12 +66,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { // multiply E on weights E *= *weightsBroad; - + switch (reductionMode) { case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. output->assign(E); break; - + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array E.reduceNumber(reduce::Sum, *output); break; @@ -80,12 +80,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
*output = 0.; - else + else output->assign(E.reduceNumber(reduce::Sum) / sum); break; } @@ -111,13 +111,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { delete weightsBroad; if(newLabels != labels) delete newLabels; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -128,11 +128,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, 
labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -142,8 +142,8 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else // in this case output has the same shape as labels and logits outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + + return SHAPELIST(outShapeInfo); } @@ -155,12 +155,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { auto logits = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto labels = INPUT_VARIABLE(2); - + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - + NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), block.launchContext()); int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" @@ -168,27 +168,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { if(reductionMode == 0) reductionMode = 1; - // input validation + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == 
labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // perform weights broadcasting/tile to labels if needed + // perform weights broadcasting/tile to labels if needed auto weightsBroad = weights; if(!weights->isScalar() && !weights->isSameShape(logits)) weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo())); - + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: auto newLabels = labels; if(labelsSmoothing.e(0) != 0.f) { newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), newLabels, nullptr); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), *newLabels); } - + NDArray E(labels, false, block.launchContext()); // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits @@ -196,24 +196,24 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { // dLdp = 1 - labels - 1 / (1 + exp(logits)) helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); - + // dLdl = -logits labelsSmoothing -= 1.f; 
dLdl->assign(*logits * labelsSmoothing); - switch (reductionMode) { + switch (reductionMode) { case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array *dLdp *= *weightsBroad; *dLdl *= *weightsBroad; - + if(weights->isScalar()) dLdw->assign(E.reduceNumber(reduce::Sum)); else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } - else + else dLdw->assign(E); break; } @@ -221,9 +221,9 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { NDArray sum; if (weights->isScalar()) sum = (*weights) * E.lengthOf(); - else + else sum = weightsBroad->reduceNumber(reduce::Sum); - + if (sum.e(0) == 0.) 
{ *dLdp = 0.; *dLdl = 0.; @@ -234,14 +234,14 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { NDArray temp = *weightsBroad / sum; *dLdp *= temp; *dLdl *= temp; - + if(weights->isScalar()) *dLdw = 0.; else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); - } - else + std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); + } + else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum)); } break; @@ -252,8 +252,8 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { if(weights->e(0) != 0.) 
numOfNonZeroWeights = E.lengthOf(); } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + else + numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); if (numOfNonZeroWeights == 0) { *dLdp = 0.; @@ -267,12 +267,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); else if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeightsScalar; } else dLdw->assign(E / numOfNonZeroWeightsScalar); - + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; *dLdp *= temp; *dLdl *= temp; @@ -285,13 +285,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { delete weightsBroad; if(newLabels != labels) delete newLabels; - + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } @@ -302,11 +302,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { auto weightsShapeInfo = inputShape->at(1); auto labelsShapeInfo = inputShape->at(2); - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and 
logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array + // check whether broadcast operation is possible for weights array REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -314,7 +314,7 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); - + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp index faabc7c18..a1a197fae 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp +++ 
b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp @@ -54,11 +54,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // num_classes = labels->sizeAt(1) - auto cLabels = labels->cast(weights->dataType()); - auto newLabels = cLabels; + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; if(labelsSmoothing != 0.) { newLabels = new NDArray(cLabels); - *newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); } // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension @@ -70,9 +70,9 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { std::vector dimensions = {-1}; - NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); + NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -217,25 +217,25 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes 
// num_classes = labels->sizeAt(1) - auto cLabels = labels->cast(weights->dataType()); - auto newLabels = cLabels; + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; if(labelsSmoothing != 0.) { newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext()); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); } - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true); + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); // dEdp = softmax * sum_i(lables_i) - labels - dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels); + dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels); // dEdl = -log(softmax) dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing)); - NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); + NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); // perform weights broadcasting/tile to E if it is necessary auto weightsBroad = weights; @@ -253,12 +253,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 
1, 1) { *dLdl *= *weights; } else { - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); + dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign(E); @@ -289,12 +289,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { else { NDArray temp = *weightsBroad / sum; - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); + dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); } else dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); @@ -326,12 +326,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { } else { NDArray temp = *weightsBroad / numOfNonZeroWeights; - dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); - dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); + 
dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp); + dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl); if(weights != weightsBroad) { std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); - E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); *dLdw /= numOfNonZeroWeights; } else diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index b129dd483..5e88ec0e6 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -34,38 +34,38 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); const int classesDim = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : logits->rankOf()-1; - - // input validation + + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - auto maxAlongDim = logits->reduceAlongDims(reduce::Max, {classesDim}, true); + std::vector dimension = {classesDim}; + + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true); auto logExp = (*logits - maxAlongDim).transform(transform::Exp); - auto logSoftMax = ( logExp / logExp.reduceAlongDims(reduce::Sum, {classesDim}, true) ).transform(transform::Log); - - (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, output, dimension); - + auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log); + + (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { - + auto logitsShapeInfo = inputShape->at(0); auto labelsShapeInfo = inputShape->at(1); const int classesDim = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : -1; std::vector dimensions = {classesDim}; - // labels and logits must have the same shapes + // labels and logits must have the same shapes REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); @@ -90,46 +90,46 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) { auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation + + // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); + std::vector dimension = {classesDim}; + + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); // dEdp = softmax * sum_i(labels_i) - labels - dLdp->assign(softmax * labels->reduceAlongDims(reduce::Sum, dimension, true) - *labels); + dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, 
true) - *labels); + + // dEdl = -log(softmax) + (-softmax).applyTransform(transform::Log, *dLdl); - // dEdl = -log(softmax) - (-softmax).applyTransform(transform::Log, dLdl); - return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits_grad) { - + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits_grad) { - auto logitsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(0); auto labelsShapeInfo = inputShape->at(1); - // labels and logits must have the same shapes + // labels and logits must have the same shapes REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - + return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp index 4c2da4d0b..e7c8da123 100644 --- 
a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp @@ -50,9 +50,9 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) std::vector dimension = {-1}; - auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true); + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); - auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log)); + auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log)); helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false); @@ -117,8 +117,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, std::vector dimension = {-1}; - NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); + NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); // dEdp = softmax - 1 (or 0) dLdp->assign(softmax); diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 8b6bd24bc..a8cd17131 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -229,19 +229,19 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // input - mean NDArray xMinusMean(input); // empty array with same shape as input - input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean); + input->applyBroadcast(nd4j::broadcast::Subtract, 
axes, *mean, xMinusMean); // stdInv NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dvdm (use dLdM as storage for dvdm) - xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); + xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape); *dLdM *= -Ninv; // g_sum - auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape); + auto gSum = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape); // dLdB if(applyOffset) @@ -249,11 +249,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) gSum *= Ninv; - dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI); - dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv); + dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, gSum, *dLdI); + dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, stdInv, *dLdI); // dLdV <- [g*(x - m)]_sum - (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); + (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape); // dLdG *dLdV *= stdInv; @@ -265,13 +265,13 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { *dLdV *= -Ninv; // -0.5f * (2 / N); // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM); - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, *dLdM, xMinusMean); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, 
*dLdV, xMinusMean); // dLdI *dLdI += xMinusMean; if(applyScale) - dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma); + dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, *gamma, *dLdI); *dLdM = 0; // put zeros so far *dLdV = 0; // put zeros so far diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 98223c5b4..a4fd7f7c3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -240,7 +240,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 3b8c51bc7..4a5bbd845 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index 1baccbe0e..1b832ea68 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ 
-244,7 +244,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { if(gradB) { if(gradB->rankOf() == 2) gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW if(gradB != OUTPUT_VARIABLE(2)) delete gradB; } diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index 84facc0cc..0754000a3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -42,35 +42,35 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { auto batchVar = OUTPUT_VARIABLE(2); // [iD] const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const bool isTraining = (bool)INT_ARG(1); + const bool isTraining = (bool)INT_ARG(1); - REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); + REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); int bS = x->sizeAt(0); // batch size - int iH, iW, iD; // input height, input width, input depth(number of channels) + int iH, iW, iD; // input height, input width, input depth(number of channels) if(dataFormat) { iD = x->sizeAt(1); iH = x->sizeAt(2); iW = x->sizeAt(3); } else { - iD = x->sizeAt(3); + iD = x->sizeAt(3); iH = x->sizeAt(1); - iW = x->sizeAt(2); - } + iW = x->sizeAt(2); + } REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str()); REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape 
of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str()); NDArray *mean(nullptr), *variance(nullptr); if(!isTraining){ - mean = INPUT_VARIABLE(3); - variance = INPUT_VARIABLE(4); + mean = INPUT_VARIABLE(3); + variance = INPUT_VARIABLE(4); REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str()); REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str()); } else { - //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); + //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); std::vector shape = {iD}; mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); @@ -78,13 +78,13 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { // FIXME: double? double epsilon; - if(block.getTArguments()->size() > 0) + if(block.getTArguments()->size() > 0) epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; - else + else epsilon = 0.001; - - const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, x->dataType(), block.launchContext()); + + const int restSize = x->lengthOf() / iD; + auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); xAffected.assign(x); const int restSizeMinusOne = (restSize > 1) ? 
(restSize - 1) : 1; @@ -93,28 +93,28 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { const double restSizeAdjust = (double)restSize / restSizeMinusOne; if(isTraining) { - auto sum = xAffected.reduceAlongDims(reduce::Sum, {0}); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); sum *= restSizeInv; mean->assign(sum); *batchMean = *mean; //delete sum; } - else + else *batchMean = 0.; - + xAffected -= *mean; - if(isTraining) { + if(isTraining) { int power = 2; - xAffected.applyScalar(scalar::Pow, power); - auto sum = xAffected.reduceAlongDims(reduce::Sum, {0}); + xAffected.applyScalar(scalar::Pow, power, xAffected); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); sum *= restSizeInv; variance->assign(sum); *batchVar = (*variance) * restSizeAdjust; //delete sum; } - else - *batchVar = 0.; + else + *batchVar = 0.; xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); y->assign( xAffected ); @@ -136,13 +136,13 @@ DECLARE_SHAPE_FN(fused_batch_norm) { const int iD = dataFormat ? 
xShapeInfo[2] : xShapeInfo[4]; REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); - + Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr); - + COPY_SHAPE(xShapeInfo, outShapeInfo); COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); - + COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); + return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp index 6dffead8b..f5cc78e2b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp @@ -37,7 +37,7 @@ namespace ops { CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + const int rank = input->rankOf(); const int dim = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : rank - 1; @@ -67,8 +67,8 @@ CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) { REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); helpers::softmax(block.launchContext(), *input, *gradI, dim); - - gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true) ); + + gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) ); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 3f5c16c17..4e62abc60 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -31,10 +31,10 @@ namespace nd4j { REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf()); REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf()); REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1)); - REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", + REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0)); - + auto output = OUTPUT_VARIABLE(0); //T bound = (T)0.f; //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf()); @@ -46,7 +46,7 @@ namespace nd4j { auto scalar = block.numT() > 0 ? 
block.getTArguments()->at(0) : 0.0; auto xw = result->at(0); - xw->applyScalar(nd4j::scalar::RELU, scalar, output); + xw->applyScalar(nd4j::scalar::RELU, scalar, *output); return Status::OK(); } @@ -55,7 +55,7 @@ namespace nd4j { auto inShape = inputShape->at(0); auto weightsShape = inputShape->at(1); auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace()); - + return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index d96f97c10..06bd6d379 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -38,7 +38,7 @@ namespace ops { CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + const int rank = input->rankOf(); const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; @@ -59,10 +59,10 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { const int dim = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); - + helpers::softmax(block.launchContext(), *input, *gradI, dim); - auto sumAlongDim = (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true); + auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); gradI->assign(*gradI * (*gradO - sumAlongDim)); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 27b6a4302..65f01cf6c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -56,7 +56,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { axes[i] = i; // mean as reduction for last dimension set - auto mean = input->reduceAlongDims(reduce::Mean, axes); + auto mean = input->reduceAlongDimension(reduce::Mean, axes); // this is contrast calculation output->assign((*input - mean) * (*factor) + mean); @@ -104,13 +104,13 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { std::vector axes({1}); // dim 1 of pseudoresult // mean as reduction for last dimension set over size (dim 1) of result3D - auto mean = input3D.reduceAlongDims(reduce::Mean, axes); + auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); // result as (x - mean) * factor + mean auto temp = input3D.ulike(); - input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr); - temp.applyScalarArr(scalar::Multiply, factor); - temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D); + input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); + temp.applyScalarArr(scalar::Multiply, *factor, temp); + temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); output->assign(output3D); 
if(block.width() == 1) delete factor; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp index 3fd5e2250..10e036b61 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp @@ -44,11 +44,11 @@ namespace nd4j { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->applyIndexReduce(indexreduce::IndexMax, output, axis); + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); } else { helpers::adjustAxis(input->rankOf(), axis); - input->applyIndexReduce(indexreduce::IndexMax, output, axis); + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); } STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp index 91e9d5a41..554b7b95b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp @@ -44,11 +44,11 @@ namespace nd4j { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->applyIndexReduce(indexreduce::IndexMin, output, axis); + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); } else { helpers::adjustAxis(input->rankOf(), axis); - input->applyIndexReduce(indexreduce::IndexMin, output, axis); + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); } STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp index b43895a31..0c88a9c53 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp @@ -82,7 +82,7 @@ CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) { gradI->assign(gradO); - 
gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); + gradO->reduceAlongDimension(nd4j::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp index fc928c3cd..822b4b91b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp @@ -45,7 +45,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { v = i++; } - std::unique_ptr outputView(output->allTensorsAlongDimension(dims)); + ResultSet outputView = output->allTensorsAlongDimension(dims); REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.", output->sizeAt(0), block.width() ); @@ -53,7 +53,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { Nd4jLong thisIndex = (*indeces).e(e); input = INPUT_VARIABLE(thisIndex); // lookup param - outputView->at(e)->assign(input); + outputView.at(e)->assign(input); } } else { @@ -87,7 +87,7 @@ DECLARE_SHAPE_FN(embedding_lookup) { int inRank = shape::rank(inShapeInfo); if (inputShape->size() == 2u) { int outRank = inRank; - + std::vector shapeInfo(outRank); shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements @@ -98,14 +98,14 @@ DECLARE_SHAPE_FN(embedding_lookup) { return SHAPELIST(outShapeInfo); } - - int outRank = inRank + 1; + + int outRank = inRank + 1; std::vector shapeInfo(outRank); auto indeces = INPUT_VARIABLE(block.width() - 1); shapeInfo[0] = indeces->lengthOf(); // vector - how many elements for (int e = 1; e < outRank; e++) shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), 
shapeInfo); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp index 12b6c9e07..5e76fefec 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp @@ -49,8 +49,8 @@ namespace nd4j { } std::vector& dims = axis; - input->varianceAlongDimension(variance::SummaryStatsVariance, variances, false, axis); - input->reduceAlongDimension(reduce::Mean, means, axis, keepDims); + input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis); + input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); return Status::OK(); } @@ -74,10 +74,10 @@ namespace nd4j { } //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - + auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - return SHAPELIST(meanShape, varianceShape); + return SHAPELIST(meanShape, varianceShape); } DECLARE_TYPES(moments) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp index 983f18bd9..e74a28184 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp @@ -52,31 +52,31 @@ namespace nd4j { case 0: { REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only"); // fro - input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); } break; case 1: { 
// euclidean if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { - input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); } else { - input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); } } break; case 2: { // 1 - input->reduceAlongDimension(reduce::Norm1, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2); } break; case 3: { - // 2 - input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2); + // 2 + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); } break; case 4: { // inf-norm - input->reduceAlongDimension(reduce::NormMax, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2); } break; default: { @@ -84,7 +84,7 @@ namespace nd4j { REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); // FIXME: p is required here //T p = T_ARG(1); - input->reduceAlongDimension(reduce::NormP, output, dims, false, output->rankOf() == 2); + input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2); } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index 23fd9a79e..15f295995 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -40,23 +40,20 @@ namespace nd4j { shift.assign(T_ARG(0)); } - means->applyScalarArr(scalar::Divide, counts, resMeans, nullptr); + 
means->applyScalarArr(scalar::Divide, *counts, *resMeans); - NDArray* squareMeans = resMeans->dup('c'); - NDArray* tempVariances = resVariances->dup('c'); + NDArray squareMeans = resMeans->dup('c'); + NDArray tempVariances = resVariances->dup('c'); - squareMeans->applyTransform(transform::Square, squareMeans, nullptr); - variances->applyScalarArr(scalar::Divide, counts, tempVariances, nullptr); -// tempVariances->printIndexedBuffer("varianced divided by count"); - tempVariances->applyPairwiseTransform(pairwise::Subtract, squareMeans, resVariances, nullptr); + squareMeans.applyTransform(transform::Square, squareMeans, nullptr); + variances->applyScalarArr(scalar::Divide, *counts, tempVariances); +// tempVariances.printIndexedBuffer("varianced divided by count"); + tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); if (shift.e(0) != 0) { - resMeans->applyScalarArr(scalar::Add, &shift, resMeans, nullptr); + resMeans->applyScalarArr(scalar::Add, shift, *resMeans); } - delete squareMeans; - delete tempVariances; - return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp index 0beec605a..f83994606 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp @@ -47,7 +47,7 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" 
, input->rankOf(), input->rankOf(), item); - input->reduceAlongDimension(reduce::Mean, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp index f1ebf91d1..6a3e7c050 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp @@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, output, biasCorrected, dimensions); + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions); return Status::OK(); } @@ -130,10 +130,10 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); const Nd4jLong NminusOne = biasCorrected ? 
N - 1 : N; - auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true); + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); NDArray variance(mean.getShapeInfo(), true, block.launchContext()); // create empty array with shape matching shape of mean array - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, &variance, biasCorrected, dimensions); + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions); gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp index dbf470935..16bfdc8a9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp @@ -54,8 +54,8 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsVariance, output, biasCorrected, dimensions); + + input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions); return Status::OK(); } @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_variance) { } REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - + for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" 
, inputShape->at(0)[0], inputShape->at(0)[0], item); @@ -128,9 +128,9 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; const double factor1 = 2.0 / NminusOne; const double factor2 = 2.0 / (N * NminusOne); - - auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true); - + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here if(!keepDims) { @@ -153,13 +153,13 @@ DECLARE_SHAPE_FN(reduce_variance_bp) { } REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - + for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - + Nd4jLong* gradIshapeInfo(nullptr); COPY_SHAPE(in, gradIshapeInfo); - + return SHAPELIST(CONSTANT(gradIshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp index 0cf0e1f9e..a02b4db9b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp @@ -45,9 +45,9 @@ namespace ops { //void* whereMax = (void*)(); auto internal = (*input); internal -= maxVals; - internal.applyTransform(transform::Exp, nullptr, nullptr); - internal.reduceAlongDimension(reduce::Sum, output, axes, keepDims, false); //, (void*)&maxVals); - output->applyTransform(transform::Log, nullptr, nullptr); + internal.applyTransform(transform::Exp, internal); + internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, 
(void*)&maxVals); + output->applyTransform(transform::Log, *output); (*output) += maxVals; return ND4J_STATUS_OK; } @@ -56,7 +56,7 @@ namespace ops { -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) -> setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(reduce_logsumexp) { + DECLARE_SHAPE_FN(reduce_logsumexp) { const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; auto input = INPUT_VARIABLE(0); @@ -74,6 +74,6 @@ namespace ops { return SHAPELIST(outShapeInfo); } -#endif +#endif } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp index 4ab9954b0..870017e8d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { else if (block.getTArguments()->size() > 0) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Max, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); return Status::OK(); } @@ -122,8 +122,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - delete indicesArr; + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp index cb9b9e21b..e8b073de8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp +++ 
b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { else if (block.getTArguments()->size() > 0) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Min, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); return Status::OK(); } @@ -89,7 +89,7 @@ DECLARE_TYPES(reduce_min) { } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_min_bp) @@ -125,8 +125,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMin, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - delete indicesArr; + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp index 8da05c3f4..172f3df8e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Norm1, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); return Status::OK(); } @@ -85,7 +85,7 @@ DECLARE_TYPES(reduce_norm1) { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_norm1_bp) ////////////////////////////////////////////////////////////////////////// @@ -100,7 +100,7 @@ CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 
0, 0) { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::Sign, gradI); + input->applyTransform(nd4j::transform::Sign, *gradI); if (gradO->lengthOf() == 1) { *gradI *= *gradO; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp index 1a7e0a911..e54518359 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp @@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Norm2, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); return Status::OK(); } @@ -124,7 +124,7 @@ CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { // *** calculations *** // - *gradI /= input->reduceAlongDims(reduce::Norm2, dimensions, true); + *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); if(!keepDims) { auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp index 902b1d699..c71310947 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp @@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::NormMax, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); return Status::OK(); } @@ -87,7 +87,7 @@ DECLARE_TYPES(reduce_norm_max) { 
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_norm_max_bp) @@ -124,9 +124,8 @@ CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { else { auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexAbsoluteMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation *gradI *= input->transform(nd4j::transform::Sign); - delete indicesArr; } return Status::OK(); @@ -139,7 +138,7 @@ DECLARE_SHAPE_FN(reduce_norm_max_bp) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); for(const auto& item : dimensions) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp index 7f3afc1c6..965b6dcaa 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Prod, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); return Status::OK(); } @@ -123,8 +123,8 @@ CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { // *** calculations *** // - auto products = input->reduceAlongDims(reduce::Prod, dimensions, true); - 
gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &products, gradI); + auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), products, *gradI); *gradI /= *input; if(!keepDims) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp index 00d277ec7..e42050ff6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp @@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - input->reduceAlongDimension(reduce::SquaredNorm, gradI, dimensions, keepDims); + input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, keepDims); return Status::OK(); } @@ -86,7 +86,7 @@ DECLARE_TYPES(reduce_sqnorm) { ->setAllowedOutputTypes({ALL_FLOATS}); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp index 4631e4807..522164593 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { else if (block.getTArguments()->size()) keepDims = (bool)T_ARG(0); - input->reduceAlongDimension(reduce::Sum, output, dimensions, keepDims); + input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); return Status::OK(); } @@ -85,7 +85,7 @@ DECLARE_TYPES(reduce_sum) { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setSameMode(true); } 
-#endif +#endif #if NOT_EXCLUDED(OP_reduce_sum_bp) ////////////////////////////////////////////////////////////////////////// @@ -123,9 +123,9 @@ CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { if(!keepDims) { auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &r, gradI); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), r, *gradI); } else - gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), gradO, gradI); + gradI->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), *gradO, *gradI); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp index 7a5753edc..fbff41c47 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Rint, z); + x->applyTransform(transform::Rint, *z); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp index b5014bb7b..34da37897 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp @@ -30,7 +30,7 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); int extras = 2; - input->applyScalar(scalar::Pow, extras, output); + input->applyScalar(scalar::Pow, extras, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp 
b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp index a43637788..81f81c326 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto out = OUTPUT_VARIABLE(0); // just for lulz - x->applyTransform(transform::Identity, out, nullptr); + x->applyTransform(transform::Identity, *out); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp index 63aa80e0a..ed7698d15 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp @@ -38,8 +38,8 @@ namespace nd4j { // axis might be dynamic (i.e. tf mode) helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->reduceAlongDimension(reduce::SquaredNorm, squares, axis); - input->reduceAlongDimension(reduce::Sum, sum, axis); + input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); + input->reduceAlongDimension(reduce::Sum, *sum, axis); auto count = NDArrayFactory::create(input->dataType(), input->lengthOf() / sum->lengthOf()); dataCount->assign(count); if (block.numT() > 0) { @@ -79,7 +79,7 @@ namespace nd4j { auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); if (block.numT() > 0) shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - + return shapeList; } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp index 090c29504..c76435622 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp @@ -38,18 +38,17 @@ namespace nd4j { REQUIRE_TRUE(v >= 0 && v 
< input->rankOf(), 0, "Tear dimensions should be non-negative values, and lower then input rank. Got %i instead", v); auto tads = input->allTensorsAlongDimension(dims); - for (Nd4jLong e = 0; e < tads->size(); e++) { + for (Nd4jLong e = 0; e < tads.size(); e++) { auto outE = OUTPUT_VARIABLE(e); - outE->assign(tads->at(e)); + outE->assign(tads.at(e)); // just for debugging purposes this->storeResult(block, e, *outE); } - delete tads; - return Status::OK(); } + DECLARE_SHAPE_FN(tear) { auto inShape = inputShape->at(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp index f6ac319ab..a44510104 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp @@ -52,21 +52,20 @@ namespace nd4j { } auto tads = input->allTensorsAlongDimension(dims); - //nd4j_printf("Tad size: %d\n",tads->size()); - for (int e = 0; e < tads->size(); e++) { + //nd4j_printf("Tad size: %d\n",tads.size()); + for (int e = 0; e < tads.size(); e++) { //nd4j_printf("Calling assign at index %d\n",e); auto outE = OUTPUT_VARIABLE(e); - auto tadAtE = tads->at(e); + auto tadAtE = tads.at(e); outE->assign(tadAtE); this->storeResult(block, e, *outE); } - delete tads; - return Status::OK(); } + DECLARE_SYN(unpack, unstack); DECLARE_SHAPE_FN(unstack) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp index 4e86690b4..ce68df1a0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp @@ -41,7 +41,7 @@ namespace nd4j { MmulHelper::mmul(x, y, z, 1.0, 0.0); // adding b vector - z->addiRowVector(b); + z->addiRowVector(*b); return Status::OK(); } @@ -49,7 +49,7 @@ namespace nd4j { DECLARE_SHAPE_FN(xw_plus_b) { auto outputShape = 
ShapeUtils::matrixProductShape(inputShape->at(0), inputShape->at(1), false, false, ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); - + return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp index 7754844d2..6ca57d297 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp @@ -73,7 +73,7 @@ CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) { auto xm = x; if(mask) { xm = new NDArray(x->getShapeInfo(), true, block.launchContext()); - x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xm, nullptr); + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); } // time loop @@ -180,7 +180,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { // x = x * mask if(applyMask) - x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask // multiplication matrix wi = matmul(w,x), U = WX auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] @@ -226,52 +226,52 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { ///////////////// forward // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) - ft.addRowVector(&bF, &ft); - rt.addRowVector(&bR, &rt); - ft.applyTransform(transform::Sigmoid, nullptr, nullptr); - rt.applyTransform(transform::Sigmoid, nullptr, nullptr); + ft.addRowVector(bF, ft); + rt.addRowVector(bR, rt); + ft.applyTransform(transform::Sigmoid, ft); + rt.applyTransform(transform::Sigmoid, rt); // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? 
reluf(cur) : cur ); - ct.applyTransform(transform::Tanh, gct); + ct.applyTransform(transform::Tanh, *gct); // ftMinus = 1-ft, rtMinus = 1-rt - ft.applyTransform(transform::OneMinus, ftMinus); - rt.applyTransform(transform::OneMinus, rtMinus); + ft.applyTransform(transform::OneMinus, *ftMinus); + rt.applyTransform(transform::OneMinus, *rtMinus); ///////////////// backward // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt) - rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt; - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; - inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + gct->applyPairwiseTransform(pairwise::Subtract, xt, *temp1); // temp1 = (g_ct - xt) + rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, *temp2); // temp2 = (1.0f - rt) * rt; + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; + inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; // bF, TODO - tanh // gradTanh = (1.0f - g_ct * g_ct); - gct->applyPairwiseTransform(pairwise::Multiply, gct, gradTanh, nullptr); // gradTanh = g_ct * g_ct - gradTanh->applyTransform(transform::OneMinus, gradTanh); // gradTanh = (1.0f - g_ct * g_ct) + gct->applyPairwiseTransform(pairwise::Multiply, *gct, *gradTanh); // gradTanh = g_ct * g_ct + gradTanh->applyTransform(transform::OneMinus, *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) // gradCt = inGradHt * rt * gradTanh - rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, gradCt, nullptr); // gradCt = rt * gradTanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, gradCt, gradCt, nullptr); // gradCt = inGradHt * rt * gradTanh + 
rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *gradCt); // gradCt = rt * gradTanh + inGradHt.applyPairwiseTransform(pairwise::Multiply, *gradCt, *gradCt); // gradCt = inGradHt * rt * gradTanh // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt) - ct_1->applyPairwiseTransform(pairwise::Subtract, &zt, temp2, nullptr); // temp2 = (ct_1 - zt) - temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft) - temp1->applyPairwiseTransform(pairwise::Multiply, &ft, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)*ft - temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) + ct_1->applyPairwiseTransform(pairwise::Subtract, zt, *temp2); // temp2 = (ct_1 - zt) + temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) + temp1->applyPairwiseTransform(pairwise::Multiply, ft, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, gradBFt); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); - inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr); + inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh - temp1->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh + 
inGradCt - temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, &gradUZt, nullptr); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *temp1); // temp1 = rt * grad_tanh + inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, *temp1); // temp1 = inGradHt * rt * grad_tanh + temp1->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt + temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); gradUFt.assign(&gradBFt); gradURt.assign(&gradBRt); // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt) - temp1->applyPairwiseTransform(pairwise::Multiply, &ft, inGradCt, nullptr); // inGradCt = (gradCt + inGradCt) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) + temp1->applyPairwiseTransform(pairwise::Multiply, ft, *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; if(t != 0) delete ct_1; @@ -283,9 +283,9 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { // gradX auto weightsT = w->transpose(); // [K x 3K] MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] - gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x + gradX->applyPairwiseTransform(pairwise::Add, *gradHX, *gradX); // + grad_highway_x if(applyMask) - gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask + gradX->applyBroadcast(broadcast::Multiply, {0,1}, *mask, *gradX); // apply mask // gradB auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K] @@ -296,7 +296,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] delete gct; delete gradU; delete gradHX; - delete temp1; delete 
temp2; delete temp3; delete gradCt; delete wi; + delete temp1; delete temp2; delete gradCt; delete wi; delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias; return Status::OK(); @@ -941,7 +941,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask // // gradB -// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K] +// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); // [1 x 2K] // // gradW [bS x 3K x inSize] // x->permutei({0, 2, 1}); // [bS x time x inSize] diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index 2cc454438..878c4c0a3 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -26,16 +26,16 @@ namespace nd4j { namespace ops { CUSTOM_OP_IMPL(tile_to_shape, 1, 1, true, 0, -1) { - + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - + std::vector outShape(block.getIArguments()->begin(), block.getIArguments()->end()); if (block.isInplace()) { - input->tileToShape(outShape); + input->tileToShape(outShape, *input); } else { - input->tileToShape(outShape, output); + input->tileToShape(outShape, *output); } return Status::OK(); @@ -44,7 +44,7 @@ namespace ops { DECLARE_SHAPE_FN(tile_to_shape) { auto in = inputShape->at(0); - // output shape always equals to arguments + // output shape always equals to arguments auto conv = ArrayUtils::toLongVector(*block.getIArguments()); @@ -73,9 +73,9 @@ namespace ops { auto gradX = OUTPUT_VARIABLE(0); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), epsNext->shapeInfo()); - // FIX ME: reduceAlongDims should have a signature with result pass to avoid assigning twice + // FIX ME: reduceAlongDimension should have a signature with result pass to avoid assigning twice if 
(!axisX.empty()) { - auto tempRes = epsNext->reduceAlongDims(reduce::Sum, axisX); + auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); gradX->assign(tempRes); } else gradX->assign(epsNext); diff --git a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp index f11d760b7..7c257b903 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp @@ -35,8 +35,8 @@ namespace nd4j { auto xO = OUTPUT_VARIABLE(0); auto yO = OUTPUT_VARIABLE(1); - x->applyScalar(scalar::Add, 1.0, xO, nullptr); - y->applyScalar(scalar::Add, 2.0, yO, nullptr); + x->applyScalar(scalar::Add, 1.0, *xO); + y->applyScalar(scalar::Add, 2.0, *yO); STORE_2_RESULTS(*xO, *yO); diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index 8ade17504..21164f520 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -63,11 +63,11 @@ namespace nd4j { sparse2dense.insert(pair); } - std::unique_ptr rows(x->allTensorsAlongDimension({1})); + ResultSet rows = x->allTensorsAlongDimension({1}); //PRAGMA_OMP_PARALLEL_FOR for (int r = 0; r < batchSize; r++) { - auto row = rows->at(r); + auto row = rows.at(r); for (int e = 0; e < numColumns; e += 2) { int idx = row->e(e); diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index 45060ad43..e7cd1ccb9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -104,34 +104,34 @@ namespace nd4j { } nd4j::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - std::unique_ptr val(output->dup()); + 
NDArray val = NDArray(output->dup()); - gradOut->applyPairwiseTransform(pairwise::Multiply, output, val.get(), nullptr); - val->applyPairwiseTransform(pairwise::Divide, input, val.get(), nullptr); + gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); + val.applyPairwiseTransform(pairwise::Divide, *input, val); if (!exclusive && !reverse) { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, false, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, true); } else if (!exclusive && reverse){ if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, false, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, false, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, false, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, false); } else if (exclusive && !reverse) { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, true); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, true, true); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, true); } else { if (dims.size()) - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, dims, true, false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); else - nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, val.get(), output, true, 
false); + nd4j::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, false); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp index f89494fd1..5a8559075 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp @@ -29,7 +29,7 @@ namespace nd4j { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - first->applyTransform(transform::Floor, z, nullptr); + first->applyTransform(transform::Floor, *z); STORE_RESULT(*z); diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp index 06656b9de..7afe9b3ed 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp @@ -56,7 +56,7 @@ namespace ops { standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output); - output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain); + output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, *gain, *output); if(bias != nullptr) { // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output); // output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias); @@ -93,8 +93,8 @@ namespace ops { if(bias != nullptr) { REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true); - eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + // eps->reduceAlongDimension(nd4j::reduce::Sum, *dLdb, {0}, true); + 
eps->reduceAlongDimension(nd4j::reduce::Sum, *dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); } NDArray standardized(input->shapeInfo(), false, block.launchContext()); @@ -106,18 +106,17 @@ namespace ops { std::vector bargs = {}; standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr); - standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, *eps, standardized); + standardized.reduceAlongDimension(nd4j::reduce::Sum, *dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); nd4j::ops::standardize_bp standardizeBp; // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx); - eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain, dLdx); + eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, *gain, *dLdx); auto dLdx_tmp = dLdx->dup(); - std::vector standardizeBpArgs = {input, dLdx_tmp}; + std::vector standardizeBpArgs = {input, &dLdx_tmp}; std::vector standardizeBpOut = {dLdx}; standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs); - delete dLdx_tmp; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp index 3d45bcf42..ef9bdb925 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp @@ -29,10 +29,10 @@ namespace nd4j { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Log1p, z, nullptr); + x->applyTransform(transform::Log1p, *z); STORE_RESULT(z); - + return Status::OK(); } DECLARE_SYN(log1p, Log1p); diff --git a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp index 
67ef3aa24..25efc1a73 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp @@ -29,28 +29,28 @@ namespace nd4j { namespace ops { CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { - + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - if (block.width() > 1) + std::vector axis; + + if (block.width() > 1) axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = *block.getIArguments(); + else if (block.numI() > 0) + axis = *block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") shape::checkDimensions(input->rankOf(), axis); - auto means = input->reduceAlongDims(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDims(variance::SummaryStatsStandardDeviation, false, axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); stdev.reshapei(means.getShapeAsVector()); - input->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &means, output, false); - output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, output, false); - output->applyScalar(nd4j::scalar::ReplaceNans, 0, output, nullptr); + input->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), means, *output, false); + output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, *output, false); + output->applyScalar(nd4j::scalar::ReplaceNans, 0, *output); return Status::OK(); } @@ -69,9 +69,9 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); std::vector axis; - if (block.width() == 3) + if (block.width() == 3) axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) + else if (block.numI() > 0) axis = *block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") @@ -80,13 +80,13 @@ namespace 
ops { shape::checkDimensions(input->rankOf(), axis); auto longAxis = ArrayUtils::toLongVector(axis); - auto means = input->reduceAlongDims(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDims(variance::SummaryStatsStandardDeviation, false, axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); stdev.reshapei(means.getShapeAsVector()); - eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, output, false); + eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, *output, false); - auto dldu_sum = -output->reduceAlongDims(reduce::Sum, axis, true); + NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); std::vector meanBpArgs = {input, &dldu_sum}; @@ -100,12 +100,12 @@ namespace ops { // (eps * (means - input) / (stdev * stdev)) NDArray tmp(eps->shapeInfo(), false, block.launchContext()); - means.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), input, &tmp, false); - tmp.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &tmp, nullptr); - stdev.applyPairwiseTransform(nd4j::pairwise::Multiply, &stdev, &stdev, nullptr); - tmp.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), &stdev, &tmp, false); + means.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), *input, tmp, false); + tmp.applyPairwiseTransform(nd4j::pairwise::Multiply, *eps, tmp); + stdev.applyPairwiseTransform(nd4j::pairwise::Multiply, stdev, stdev); + tmp.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Divide(), stdev, tmp, false); - auto dlds_sum = tmp.reduceAlongDims(reduce::Sum, axis, true); + auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); std::vector stdevBpArgs = {input, &dlds_sum}; std::vector stdevBpOutput = {&dldx_s}; @@ -115,7 +115,7 @@ namespace ops { 
stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, stdevBpBArgs); *output += dldx_s; - output->applyScalar(nd4j::scalar::ReplaceNans, 0, output, nullptr); + output->applyScalar(nd4j::scalar::ReplaceNans, 0, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/tri.cpp b/libnd4j/include/ops/declarable/generic/transforms/tri.cpp index 727d42ba5..a6106f197 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tri.cpp @@ -32,8 +32,8 @@ CUSTOM_OP_IMPL(tri, -2, 1, false, 0, 1) { const int diag = block.numI() > 2 ? INT_ARG(2) : 0; - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix + BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, *output, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix + BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix // output->setValueInDiagMatrix(1., diag, 'l'); // output->setValueInDiagMatrix(0., diag+1, 'u'); diff --git a/libnd4j/include/ops/declarable/generic/transforms/triu.cpp b/libnd4j/include/ops/declarable/generic/transforms/triu.cpp index 1c4214e9b..b382cbfb1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/triu.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(triu, 1, 1, false, 0, 0) { const int diag = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : 0; - BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, 'l', output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, *output, 'l' ), LIBND4J_TYPES); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp index a7123d42f..f8704d7b0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp @@ -195,7 +195,7 @@ namespace helpers { return res; }; - input->applyTriplewiseLambda(gradX, epsilon, gainsInternal, output); + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); } void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp index ba0f36eb5..9a11baf37 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp @@ -115,9 +115,9 @@ static void softMaxForVector_(void *input, Nd4jLong *inShapeInfo, void *output, BUILD_SINGLE_SELECTOR(input.dataType(), _softMaxDerivForVector, (context, input.getBuffer(), input.getShapeInfo(), output.buffer()), FLOAT_TYPES); } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output *= (1.f 
- output); // derivative } @@ -204,7 +204,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra else output = 1.; } - else if(input.isSameShapeStrict(&output)) { + else if(input.isSameShapeStrict(output)) { TadPack tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimension); Nd4jLong* tadShapeInfo = tadPack.primaryShapeInfo(); @@ -275,10 +275,10 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra } } else { - NDArray max = input.reduceAlongDims(nd4j::reduce::Max, {dimension}, true); - input.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &max, &output, false); - output.applyTransform(nd4j::transform::Exp); - NDArray sum = output.reduceAlongDims(nd4j::reduce::Sum, {dimension}, true); + NDArray max = input.reduceAlongDimension(nd4j::reduce::Max, {dimension}, true); + input.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), max, output, false); + output.applyTransform(nd4j::transform::Exp, output); + NDArray sum = output.reduceAlongDimension(nd4j::reduce::Sum, {dimension}, true); output /= sum; } } @@ -347,7 +347,7 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& auto routine = LAMBDA_T(_x, threshold) { return _x > (T)threshold? 
_x: (T)0.f; }; - const_cast(input).applyLambda(routine, &output); + const_cast(input).applyLambda(routine, output); } void thresholdRelu(nd4j::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { @@ -358,7 +358,7 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& static void thresholdReluDerivative_(nd4j::LaunchContext * context, NDArray* input, double theta, NDArray* dLdO, NDArray* output) { auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - input->applyPairwiseLambda(dLdO, derivative, output); + input->applyPairwiseLambda(*dLdO, derivative, *output); } @@ -381,11 +381,11 @@ void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; - output.applyTransform(transform::Log); + output.applyTransform(transform::Log, output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index a64864b1b..204ceaf81 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -141,7 +141,7 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, } else { const int channelDim = isNCHW ? 
1 : input.rankOf() - 1; // second or last - const_cast(input).applyBroadcast(nd4j::broadcast::Add, {channelDim}, &bias, &output); + const_cast(input).applyBroadcast(nd4j::broadcast::Add, {channelDim}, bias, output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index 7a0d8b97b..ada2c5d72 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -40,11 +40,11 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* if(gamma != nullptr) { auto lambda = LAMBDA_TT(x, y, eps) {return x / nd4j::math::nd4j_sqrt(y + eps);}; - const_cast(gamma)->applyPairwiseLambda(variance, lambda, &sigmaInvGam); + const_cast(gamma)->applyPairwiseLambda(*variance, lambda, sigmaInvGam); } else { auto lambda = LAMBDA_T(x, eps) { return 1. / nd4j::math::nd4j_sqrt(x + eps); }; - const_cast(variance)->applyLambda(lambda, &sigmaInvGam); + const_cast(variance)->applyLambda(lambda, sigmaInvGam); } // auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp index e2d24c591..4f8989caf 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp @@ -28,7 +28,7 @@ namespace helpers { template void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - std::unique_ptr arrs(output->allTensorsAlongDimension({1})); + ResultSet arrs = output->allTensorsAlongDimension({1}); int lLen = labels->lengthOf(); auto func = PRAGMA_THREADS_FOR { @@ -36,7 +36,7 @@ namespace helpers { auto label = labels->e(j); auto pred = predictions->e(j); T value = (weights == nullptr ? 
(T) 1.0f : weights->e(j)); - (*arrs->at(label)).p(pred, value); + arrs.at(label)->p(pred, value); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 47938e9fb..db09f0d3c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -373,7 +373,7 @@ namespace nd4j { NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } @@ -506,7 +506,7 @@ namespace nd4j { NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp index 3150c0cfd..0adb0e249 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp @@ -36,23 +36,19 @@ void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray auto tadsB = _b.allTensorsAlongDimension({1}); auto tadsO = _o.allTensorsAlongDimension({1}); - int tads = tadsA->size(); + int tads = tadsA.size(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto a_ = tadsA->at(e); - auto b_ = tadsB->at(e); - auto o_ = tadsO->at(e); + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); helpers::cross(context, a_, b_, o_); } }; 
samediff::Threads::parallel_tad(func, 0, tads); - - delete tadsA; - delete tadsB; - delete tadsO; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp index 073167f18..281e6c809 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp @@ -34,7 +34,7 @@ namespace nd4j { for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); unsigned int outSize = outputList.size(); @@ -48,15 +48,14 @@ namespace nd4j { for (int k = 1; k < r; k++) outDims[k - 1] = k; - std::unique_ptr listOutForCurrent( - outputs[i].first->allTensorsAlongDimension(outDims)); + ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); outputs[i].second = 0; //PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > Environment::getInstance()->elementwiseThreshold()) for (int e = 0; e < indices->lengthOf(); ++e) if ((*indices).e(e) == i) - listOutForCurrent->at(outputs[i].second++)->assign(listOfTensors->at(e)); + listOutForCurrent.at(outputs[i].second++)->assign(listOfTensors.at(e)); } } else { @@ -104,7 +103,7 @@ namespace nd4j { for (int i = restDims.size(); i > 0; i--) restDims[restDims.size() - i] = output->rankOf() - i; - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int e = 0; e < numOfData; e++) { auto data = inputs[e]; @@ -113,7 +112,7 @@ namespace nd4j { for (int i = sourceDims.size(); i > 0; i--) sourceDims[sourceDims.size() - i] = data->rankOf() - i; - std::unique_ptr listOfTensors(data->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims) ; for (int i = 0; i < 
index->lengthOf(); i++) { auto pos = index->e(i); @@ -127,7 +126,7 @@ namespace nd4j { return ND4J_STATUS_VALIDATION; } - listOfOutTensors->at(pos)->assign(listOfTensors->at(i)); + listOfOutTensors.at(pos)->assign(listOfTensors.at(i)); } } } @@ -145,7 +144,7 @@ namespace nd4j { for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - std::unique_ptr listOfTensors(outputList[0]->allTensorsAlongDimension(sourceDims)); + ResultSet listOfTensors = outputList[0]->allTensorsAlongDimension(sourceDims); for (unsigned int i = 0; i < inputGradientList.size(); i++) { outputs[i].first = inputGradientList[i]; @@ -155,14 +154,13 @@ namespace nd4j { for (int k = 1; k < outputs[i].first->rankOf(); k++) outDims[k - 1] = k; - std::unique_ptr listOutForCurrent( - outputs[i].first->allTensorsAlongDimension(outDims)); + ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); outputs[i].second = 0; for (int e = 0; e < indices->lengthOf(); ++e) if (indices->e(e) == i) - listOfTensors->at(e)->assign(listOutForCurrent->at(outputs[i].second++)); + listOfTensors.at(e)->assign(listOutForCurrent.at(outputs[i].second++)); } } else { // one-dimensional case diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index f3fe89103..0a46c995e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -28,14 +28,14 @@ namespace helpers { template static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ std::vector restDims({1, 2, 3}); // the first and the last dims - std::unique_ptr listOfMatricies(images->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutputs(output->allTensorsAlongDimension(restDims)); + ResultSet listOfMatricies = 
images->allTensorsAlongDimension(restDims); + ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) //int e = 0; const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); const int ksize = ksizeRowsEffective * ksizeColsEffective; - int batchCount = listOfMatricies->size(); //lengthOf() / ksize; + int batchCount = listOfMatricies.size(); //lengthOf() / ksize; Nd4jLong lastDim = images->sizeAt(3); Nd4jLong outLastDim = output->sizeAt(3); Nd4jLong rowDim = images->sizeAt(1); @@ -51,8 +51,8 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto batch = 0; batch < stop; batch += increment) { - auto patch = listOfMatricies->at(batch); - auto outMatrix = listOfOutputs->at(batch); + auto patch = listOfMatricies.at(batch); + auto outMatrix = listOfOutputs.at(batch); for (Nd4jLong i = 0; i < outRowDim; i++) { for (Nd4jLong j = 0; j < outColDim; j++) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index 81f00b066..f18f48fac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -105,7 +105,7 @@ namespace helpers { return (nd4j::math::nd4j_floor(val / scale + T(0.5f)) * scale + nudgedMin); }; - input->applyLambda(fakeQuantizationWithMinMax, output); + input->applyLambda(fakeQuantizationWithMinMax, *output); } void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp index ecc4cf24a..f6756dd88 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp +++ 
b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp @@ -30,7 +30,7 @@ static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, return _x - (_y * weight); }; - input->applyPairwiseLambda(step, lambda, output); + input->applyPairwiseLambda(*step, lambda, *output); } void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp index 579ab2612..3db5e5373 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp @@ -77,15 +77,15 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa // reset gate r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid); + r->applyTransform(transform::Sigmoid, *r); // update gate u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid); + u->applyTransform(transform::Sigmoid, *u); // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh); + c->applyTransform(transform::Tanh, *c); NDArray temp = 1.f - *c * *c; @@ -231,15 +231,15 @@ void gruCellBP(nd4j::LaunchContext* context, // reset gate NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid); + r.applyTransform(transform::Sigmoid, r); // update gate NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid); + u.applyTransform(transform::Sigmoid, u); // 
cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh); + c.applyTransform(transform::Tanh, c); // h = (1 - u) * c + u * hPrev @@ -352,10 +352,10 @@ void gruCellBP(nd4j::LaunchContext* context, dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } // ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 62f8316ce..4db975ddf 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -31,7 +31,7 @@ namespace helpers { return x > (T) 0.f ? y : T(0.f); }; - theFirst->applyPairwiseLambda(theSecond, functor, nullptr); + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { @@ -46,7 +46,7 @@ namespace helpers { return x > zero ? y : zero; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); /* auto x = input->bufferAsT(); @@ -74,7 +74,7 @@ namespace helpers { return x > (T)0.f && x < (T)6.f? 
y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -90,7 +90,7 @@ namespace helpers { return x < 0 ? alphaT * y : y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -106,7 +106,7 @@ namespace helpers { return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -119,7 +119,7 @@ namespace helpers { return y * simdOps::SELUDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -132,7 +132,7 @@ namespace helpers { return y * (3 * x * x); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -146,7 +146,7 @@ namespace helpers { return x > T(0.f)? 
y : -y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -160,7 +160,7 @@ namespace helpers { return nd4j::math::nd4j_max(x, (T)0.f) - x * y + nd4j::math::nd4j_log((T)1.f + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -178,7 +178,7 @@ namespace helpers { return static_cast(1.) - y - e / (static_cast(1.) + e); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -193,7 +193,7 @@ namespace helpers { return y * ((T)1.0f - (th * th)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -208,7 +208,7 @@ namespace helpers { return y * simdOps::HardTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -221,7 +221,7 @@ namespace helpers { return y * simdOps::RationalTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -234,7 +234,7 @@ namespace helpers { return x > (T) 0.0f ? 
y * (nd4j::math::nd4j_tanhderivative(x)) : (T) 0.0f; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -251,7 +251,7 @@ namespace helpers { return y * ((T) 1.0f / (ss * ss)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -265,7 +265,7 @@ namespace helpers { return y * (p / (p + 1.)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -282,7 +282,7 @@ namespace helpers { return y * (s * ((T) 1.0f - s)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -295,7 +295,7 @@ namespace helpers { return y * simdOps::HardSigmoidDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -305,24 +305,24 @@ namespace helpers { template static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyTransform(transform::Exp, tempInput.get()); + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { axisVector.resize(axis->lengthOf()); for (size_t i = 0; i < 
axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } template static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyPairwiseTransform(pairwise::Subtract, subtrah, tempInput.get(), nullptr); - tempInput->applyTransform(transform::Exp, nullptr, nullptr); + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { @@ -330,8 +330,8 @@ namespace helpers { for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { @@ -364,16 +364,16 @@ static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArr if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(targets), mainRoutineT1, output); + const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); } else { std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); std::unique_ptr targetTensor(new NDArray(*targets)); *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(targets), targetTensor.get(), mainRoutineT2, output); + 
const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 922fdc3a9..683a82392 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -86,7 +86,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue); + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); if(peephole) zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc @@ -99,7 +99,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] // if clipping projection is provided then projected cell output state is clipped by this value if(clippingProjValue != 0.) 
- ht->applyScalar(scalar::LstmClip, clippingProjValue); + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); } else ht->assign(&htNoPeepHole); @@ -199,13 +199,13 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast PRAGMA_OMP_SINGLE { PRAGMA_OMP_TASK - zz.applyTransform(transform::Tanh, z); //z = tanh(zz) + zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) PRAGMA_OMP_TASK - zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi) + zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) PRAGMA_OMP_TASK - zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf); + zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); } if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && f->ews() == 1 && h->ews() == 1 && @@ -214,15 +214,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), FLOAT_TYPES); } else { //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i auto temp = (*f) * (*cLast); *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) + c->applyTransform(transform::Tanh, *h); //h = tanh(c) } // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue); + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); // add peephole connections to output gate zot + ct*Wc if(peephole) { @@ -230,11 +230,11 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast zo += prod; } - zo.applyTransform(transform::Sigmoid, o); // o = sigmoid(zo) + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) // current cell output = ot*tanh(ct) 
- c->applyTransform(transform::Tanh, h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, h, y, nullptr); //y = o * h + c->applyTransform(transform::Tanh, *h); //h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index d706eaff3..9c7cb1bfe 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -264,14 +264,14 @@ namespace helpers { auto n = input->sizeAt(-1); output->assign(input); // fill up output tensor with zeros - std::unique_ptr outputs(output->allTensorsAlongDimension({-2, -1})); - std::unique_ptr permutations(permutationVectors->allTensorsAlongDimension({-1})); + ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); + ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1}); auto loop = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { - luNN_(context, outputs->at(i), permutations->at(i), n); + luNN_(context, outputs.at(i), permutations.at(i), n); } }; - samediff::Threads::parallel_for(loop, 0, outputs->size(), 1); + samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); } void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { @@ -384,13 +384,13 @@ template template static bool checkCholeskyInput_(nd4j::LaunchContext * context, NDArray const* input) { //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace()); - std::unique_ptr lastMatrixList(input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1})); - for (size_t i = 0; i < lastMatrixList->size(); i++) { - auto thisMatrix = lastMatrixList->at(i); + ResultSet lastMatrixList = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1}); + for (size_t i = 0; i < lastMatrixList.size(); i++) { + auto thisMatrix = 
lastMatrixList.at(i); // check for symmetric for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) - if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList->at(i)->e(c,r)) > DataTypeUtils::min()) return false; + if (nd4j::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList.at(i)->e(c,r)) > DataTypeUtils::min()) return false; NDArray output = NDArrayFactory::create(0., context); if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; @@ -459,21 +459,18 @@ template template int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { - std::unique_ptr tempOutput(input->dup()); - int res = cholesky_(context, input, tempOutput.get(), false); + auto tempOutput = input->dup(); + int res = cholesky_(context, input, &tempOutput, false); if (res != ND4J_STATUS_OK) return res; auto n = input->sizeAt(-1); auto totalCount = output->lengthOf(); std::vector d(n); - std::unique_ptr matricies(tempOutput->allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1})); - std::unique_ptr inputMatricies(input->allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1})); - for (Nd4jLong e = 0; e < totalCount; e++) { + ResultSet matricies = tempOutput.allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1}); - //d[0] = inputMatricies->at(e)->t(0, 0); - for (size_t i = 0; i < n; ++i) { - output->t(e) += nd4j::math::nd4j_log(nd4j::math::nd4j_pow(matricies->at(e)->t(i, i), T(2))); - } + for (Nd4jLong e = 0; e < totalCount; e++) { + for (size_t i = 0; i < n; ++i) + output->t(e) += nd4j::math::nd4j_log(nd4j::math::nd4j_pow(matricies.at(e)->t(i, i), T(2))); } return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp index 399d89e32..fbab49e80 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp 
@@ -30,11 +30,11 @@ namespace helpers { Nd4jLong N = input->sizeAt(-1); Nd4jLong lastDim = input->rankOf() - 1; Nd4jLong preLastDim = input->rankOf() - 2; - std::unique_ptr listOut(output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - std::unique_ptr listDiag(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - for (Nd4jLong e = 0; e < listOut->size(); ++e) { - NDArray* inputMatrix = listDiag->at(e); - NDArray* outputMatrix = listOut->at(e); + ResultSet listOut = output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + ResultSet listDiag = input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + for (Nd4jLong e = 0; e < listOut.size(); ++e) { + NDArray* inputMatrix = listDiag.at(e); + NDArray* outputMatrix = listOut.at(e); if (outputMatrix != inputMatrix) // if not inplace outputMatrix->assign(inputMatrix); if (lowerBand >= 0) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp index e0e487e82..cc43c1866 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp @@ -37,24 +37,21 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) { auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - if (listOut->size() != listDiag->size()) { + if (listOut.size() != listDiag. 
size()) { nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); return ND4J_STATUS_VALIDATION; } int lastDimension = nd4j::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); // TODO: tune this properlys - int lO = listOut->size(); + int lO = listOut.size(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) for (int j = 0; j < lastDimension; ++j) - listOut->at(i)->p(j, listDiag->at(i)->e(j, j)); + listOut.at(i)->p(j, listDiag.at(i)->e(j, j)); }; samediff::Threads::parallel_tad(func, 0, lO); - - delete listOut; - delete listDiag; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp index b59c16afe..a8a0d919d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp @@ -39,13 +39,11 @@ void meshgrid(nd4j::LaunchContext * context, const std::vector& inArrs inIndices[0] = 1; inIndices[1] = 0; } - - for(int i = 0; i < rank; ++i) { - auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - for(int j = 0; j < list->size(); ++j) - list->at(j)->assign(inArrs[i]); - delete list; + for(int i = 0; i < rank; ++i) { + auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + for(int j = 0; j < list.size(); ++j) + list.at(j)->assign(inArrs[i]); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp index 61b6465ba..8d94d23ca 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { namespace helpers { - template + template static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { auto lambdaX = LAMBDA_TTT(_e, _x, _y) { @@ -43,10 +43,10 @@ namespace helpers { // PWT case case // X gradient - 
epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -60,8 +60,8 @@ namespace helpers { gradY->assign(tmp); else gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace helpers { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace helpers { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } @@ -116,10 +110,10 @@ namespace helpers { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -133,8 +127,8 @@ namespace helpers { gradY->assign(tmp); else gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + + epsNext->applyPairwiseLambda(*x, lambdaS, 
*gradX); } else { // broadcast case @@ -144,8 +138,8 @@ namespace helpers { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -154,22 +148,16 @@ namespace helpers { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; - } else + } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp index 8c5332be6..dcca5075e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp @@ -51,12 +51,12 @@ namespace helpers { SpecialMethods::sortTadGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), lastDims.data(), lastDims.size(), pack.primaryShapeInfo(), pack.primaryOffsets(), reverse); - std::unique_ptr rows(sortedVals.allTensorsAlongDimension(lastDims)); + ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); Nd4jLong oL = output->lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto row = rows->at(e); + auto row = rows.at(e); output->p(e, row->e(n)); } }; @@ -70,7 +70,7 @@ namespace helpers { } BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); - + 
} } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp index 5c1f3c28d..fa8061e54 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp @@ -30,8 +30,8 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void _percentile(const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - - const int inputRank = input.rankOf(); + + const int inputRank = input.rankOf(); if(axises.empty()) for(int i=0; i& auto listOfSubArrs = input.allTensorsAlongDimension(axises); - - std::vector shapeOfSubArr(listOfSubArrs->at(0)->rankOf()); + + std::vector shapeOfSubArr(listOfSubArrs.at(0)->rankOf()); for(int i=0; iat(0)->shapeOf()[i]; + shapeOfSubArr[i] = listOfSubArrs.at(0)->shapeOf()[i]; auto flattenedArr = NDArrayFactory::create('c', shapeOfSubArr, input.dataType(), input.getContext()); const int len = flattenedArr.lengthOf(); - + const float fraction = 1.f - q / 100.; Nd4jLong position = 0; - + switch(interpolation) { case 0: // lower position = static_cast(math::nd4j_ceil((len - 1) * fraction)); @@ -67,15 +67,13 @@ static void _percentile(const NDArray& input, NDArray& output, std::vector& // FIXME: our sort impl should be used instead, so this operation might be implemented as generic // FIXME: parallelism ! 
- for(int i=0; isize(); ++i) { - + for(int i=0; i(flattenedArr.getBuffer()); - flattenedArr.assign(listOfSubArrs->at(i)); + flattenedArr.assign(listOfSubArrs.at(i)); std::sort(buff, buff + len); output.p(i, flattenedArr.e(position)); } - - delete listOfSubArrs; } void percentile(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp index f46346876..43c65f14b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp @@ -101,17 +101,14 @@ namespace nd4j { static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { auto xTads = x->allTensorsAlongDimension(dims); auto zTads = z->allTensorsAlongDimension(dims); - auto t = xTads->size(); + auto t = xTads.size(); for (int e = 0; e < t; e++) { - auto tx = xTads->at(e); - auto tz = zTads->at(e); + auto tx = xTads.at(e); + auto tz = zTads.at(e); prefix_(op, tx->buffer(), tx->shapeInfo(), tz->buffer(), tz->shapeInfo(), exclusive, reverse); } - - delete xTads; - delete zTads; }; template diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index 3f9788330..f25859b1c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -46,8 +46,8 @@ namespace helpers { NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); - copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha)); - copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta)); + copyAlpha = new 
NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); } // bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c'; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index 9f424606d..9ee906bd5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -171,27 +171,21 @@ static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - for(int i = 0; i < inSubArrsSet->size(); ++i) { + for(int i = 0; i < inSubArrsSet.size(); ++i) { Nd4jLong numOfElemsToReverse = seqLengths->e(i); if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet->at(i)->assign(inSubArrsSet->at(i)); + outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); } else { - auto inInnerSet = inSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet->size(); ++j) - helpers::reverseArray(context, inInnerSet->at(j)->getBuffer(), inInnerSet->at(j)->getShapeInfo(), outInnerSet->at(j)->getBuffer(), outInnerSet->at(j)->getShapeInfo(), numOfElemsToReverse); - - delete inInnerSet; - delete outInnerSet; + auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + for(int j = 0; j < inInnerSet.size(); ++j) + helpers::reverseArray(context, inInnerSet.at(j)->getBuffer(), inInnerSet.at(j)->getShapeInfo(), outInnerSet.at(j)->getBuffer(), outInnerSet.at(j)->getShapeInfo(), numOfElemsToReverse); } } - delete inSubArrsSet; - delete outSubArrsSet; } - } void 
reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { @@ -209,14 +203,11 @@ void reverse(nd4j::LaunchContext * context, const NDArray* input, NDArray* outpu NDArray *subArrIn, *subArrOut; - for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size() - subArrIn = listIn->at(i); - subArrOut = listOut->at(i); + for(int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() + subArrIn = listIn.at(i); + subArrOut = listOut.at(i); BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->getBuffer(), subArrIn->getShapeInfo(), subArrOut->getBuffer(), subArrOut->getShapeInfo()), LIBND4J_TYPES); } - - delete listOut; - delete listIn; } BUILD_SINGLE_TEMPLATE(template void reverseSequence_, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim), LIBND4J_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index b3b65f816..8bfc1ca1a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -40,8 +40,8 @@ namespace helpers { if (actualShift) { int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; - + int remainShift = fullLen % actualShift; + // stage 1) swap last actualShift elements with first ones. //PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > Environment::getInstance()->elementwiseThreshold()) for (int e = 0; e < actualShift; ++e) { @@ -70,7 +70,7 @@ namespace helpers { output->p(sourceIndex, _e0); } } - + // stage 3) swap remainer of items. 
if (remainShift && shiftCount) for (int i = actualShift; i < 2 * actualShift; ++i) { @@ -94,9 +94,9 @@ namespace helpers { for (size_t i = 0; i < axes.size(); i++) { int axe = axes[i]; if (axe == source->rankOf() - 1) {// last dimension - std::unique_ptr listOfTensors(source->allTensorsAlongDimension({axe})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); - int fullLen = listOfTensors->size(); + ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); int theShift = shifts[i]; if (theShift > 0) { theShift %= fullLen; @@ -105,7 +105,7 @@ namespace helpers { theShift -= fullLen * (theShift / fullLen - 1); } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(context, listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); + rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); } } else { @@ -113,10 +113,10 @@ namespace helpers { for (int i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; - std::unique_ptr listOfTensors(source->allTensorsAlongDimension({dims})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({dims})); + ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); // - int fullLen = listOfTensors->size(); + int fullLen = listOfTensors.size(); int sizeAt = input->sizeAt(axe); int theShift = shifts[i]; @@ -131,16 +131,16 @@ namespace helpers { if (theShift) { for (int dim = 0; dim < fullLen / sizeAt; ++dim) { for (int e = theShift; e < sizeAt - theShift; ++e) { - auto sourceM = listOfTensors->at(dim * sizeAt + e - theShift); - auto targetM = listOfOutTensors->at(dim * sizeAt + e); + auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); sourceM->swapUnsafe(*targetM); } - + for (int e = 0; e < 
theShift; ++e) { int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - auto sourceM = listOfTensors->at(sourceIndex); - auto targetM = listOfOutTensors->at(dim * sizeAt + e); - + auto sourceM = listOfTensors.at(sourceIndex); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + sourceM->swapUnsafe(*targetM); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 9ae191c76..a3f0c01be 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -83,7 +83,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind Nd4jLong idx = indices.e(i); NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i), nullptr); + out.applyPairwiseTransform(op, updates.e(i)); } }; @@ -103,7 +103,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind NDArray outSubArr = output(indices.e(i), std::vector({0})); NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); + outSubArr.applyPairwiseTransform(op, updSubArr); } }; @@ -150,7 +150,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i NDArray outSubArr = output(idxRangeOut); NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); + outSubArr.applyPairwiseTransform(op, updSubArr); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 2884107f3..5e21e3b8e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -56,31 +56,29 @@ namespace helpers { auto numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto maxT = listOfOutTensors->at(idx); + auto maxT = 
listOfOutTensors.at(idx); //int pos = 0; - maxT->assign(listOfTensors->at(0)); + maxT->assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->t(e) = nd4j::math::nd4j_max(maxT->t(e), listOfTensors->at(i)->t(e)); + maxT->t(e) = nd4j::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); } } else { idx = indices->e(i); - maxT = listOfOutTensors->at(idx); - maxT->assign(listOfTensors->at(i)); + maxT = listOfOutTensors.at(idx); + maxT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } - // segmen min + // segmen min template static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { //int numClasses = output->sizeAt(0); @@ -91,7 +89,7 @@ namespace helpers { for (int e = 1; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // min + // min val = nd4j::math::nd4j_min(val, input->t(e)); } else { @@ -104,27 +102,27 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors( input->allTensorsAlongDimension(restDims) ); - std::unique_ptr listOfOutTensors( output->allTensorsAlongDimension(restDims) ); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto minT = listOfOutTensors->at(idx); + auto minT = listOfOutTensors.at(idx); int pos = 0; - minT->assign(listOfTensors->at(0)); + minT->assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { for (int e = 0; e < minT->lengthOf(); e++) { - minT->p(e, nd4j::math::nd4j_min(minT->e(e), listOfTensors->at(i)->e(e))); + minT->p(e, nd4j::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); } } else { idx = indices->e(i); - minT = 
listOfOutTensors->at(idx); - minT->assign(listOfTensors->at(i)); + minT = listOfOutTensors.at(idx); + minT->assign(listOfTensors.at(i)); } } } @@ -142,7 +140,7 @@ namespace helpers { for (int e = 0; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // mean + // mean val += input->e(e); count++; } @@ -163,16 +161,16 @@ namespace helpers { int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto meanT = listOfOutTensors->at(idx); + auto meanT = listOfOutTensors.at(idx); int count = 1; auto meanV = meanT->dup(); - meanV->assign(listOfTensors->at(0)); + meanV.assign(listOfTensors.at(0)); for (int i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - meanV->p(e, meanV->e(e) + listOfTensors->at(i)->e(e)); + meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); @@ -181,17 +179,14 @@ namespace helpers { } else { //meanT->assign(meanV); - meanV->applyScalar(scalar::Divide, count, meanT, nullptr); + meanV.applyScalar(scalar::Divide, count, *meanT); idx = indices->e(i); - meanT = listOfOutTensors->at(idx); - meanV->assign(listOfTensors->at(i)); + meanT = listOfOutTensors.at(idx); + meanV.assign(listOfTensors.at(i)); count = 1; } - meanV->applyScalar(scalar::Divide, count, meanT, nullptr); + meanV.applyScalar(scalar::Divide, count, *meanT); } - delete meanV; - delete listOfTensors; - delete listOfOutTensors; } } @@ -205,7 +200,7 @@ namespace helpers { int count = 0; for (int e = 0; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // sum + // sum val += input->t(e); } else { @@ -223,25 +218,23 @@ namespace helpers { int numOfClasses = output->sizeAt(0); // number of classes std::vector> outputs(numOfClasses); - auto sumT = listOfOutTensors->at(idx); + auto sumT = listOfOutTensors.at(idx); for (int i = 0; i < indices->lengthOf(); i++) { if 
(indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - sumT->p(e, sumT->e(e) + listOfTensors->at(i)->e(e)); + sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); } else { idx = indices->e(i); - sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(i)); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } @@ -257,7 +250,7 @@ namespace helpers { for (int e = 1; e < indices->lengthOf(); e++) { if (idx == indices->e(e)) { - // sum + // sum val *= input->e(e); } else { @@ -274,25 +267,23 @@ namespace helpers { auto listOfOutTensors = output->allTensorsAlongDimension(restDims); int numOfClasses = output->sizeAt(0); // number of classes - auto sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(0)); + auto sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(0)); for (int i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - sumT->p(e, sumT->e(e) * listOfTensors->at(i)->e(e)); + sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); } }; samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); } else { idx = indices->e(i); - sumT = listOfOutTensors->at(idx); - sumT->assign(listOfTensors->at(i)); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); } } - delete listOfTensors; - delete listOfOutTensors; } } @@ -380,24 +371,23 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); T maxVal 
= DataTypeUtils::max(); output->assign(-maxVal); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto maxT = listOfTensors->at(fi->second.at(idx)); + auto maxT = listOfTensors.at(fi->second.at(idx)); for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { T val = nd4j::math::nd4j_max(maxT->e(e), outputT->e(e)); outputT->p(e, val); } } - //outputT->assign(maxT); } } } @@ -433,17 +423,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); T maxVal = DataTypeUtils::max(); output->assign(maxVal); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto minT = listOfTensors->at(fi->second.at(idx)); + auto minT = listOfTensors.at(fi->second.at(idx)); for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { outputT->t(e) = nd4j::math::nd4j_min(minT->t(e), outputT->t(e)); @@ -485,17 +475,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = 
input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); // FIXME: parallelism here? for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loopSize = fi->second.size(); for (Nd4jLong idx = 1; idx < loopSize; ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT += *current; } (*outputT) /= double(fi->second.size()); @@ -524,17 +514,17 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loop_size = fi->second.size(); // FIXME: parallelism here? 
for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *(outputT) += *current; } //outputT->assign(maxT); @@ -564,14 +554,14 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT *= *current; } @@ -603,14 +593,14 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors->at(fi->first); - outputT->assign(listOfTensors->at(fi->second.at(0))); + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors->at(fi->second.at(idx)); + auto current = listOfTensors.at(fi->second.at(idx)); *outputT += *current; } //outputT->assign(maxT); @@ 
-630,14 +620,14 @@ namespace helpers { //int numOfClasses = gradOut->sizeAt(0); // if input is a vector: (as if in doc sample) auto tempRes = gradOut->dup(); - segmentMaxFunctor_(input, indices, tempRes); + segmentMaxFunctor_(input, indices, &tempRes); if (input->isVector()) { Nd4jLong loop_size = input->lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) <= T(1.e-6)) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) output->p(e, gradOut->e(classNum)); } }; @@ -646,23 +636,23 @@ namespace helpers { else { std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (uint64_t e = 0; e < current->lengthOf(); e++) { - if 
(nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) currentOut->p(e, currentGradOut->e(e)); } } @@ -670,7 +660,7 @@ namespace helpers { samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -681,13 +671,13 @@ namespace helpers { // segmen min int segmentMinFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - std::unique_ptr tempRes(gradOut->dup()); - segmentMinFunctor(context, input, indices, tempRes.get()); + NDArray tempRes = gradOut->dup(); + segmentMinFunctor(context, input, indices, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) output->p(e, gradOut->e(classNum)); } }; @@ -696,12 +686,12 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); output->assign(0.); int pos = 0; 
@@ -709,12 +699,12 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) < + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) currentOut->p(e, currentGradOut->e(e)); } @@ -749,20 +739,18 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); - - //int numOfClasses = tempRes->sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); +; int pos = 0; //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); @@ -788,16 +776,16 @@ namespace helpers { else { 
auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(currentGradOut); } @@ -810,31 +798,31 @@ namespace helpers { int segmentProdFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { auto tempRes = gradOut->dup(); - segmentProdFunctor(context, input, indices, tempRes); + segmentProdFunctor(context, input, indices, &tempRes); if (input->isVector()) { for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes->e(classNum)/ input->e(e)); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); } } else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + 
ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //int numOfClasses = tempRes->sizeAt(0); // number of classes + //int numOfClasses = tempRes.sizeAt(0); // number of classes //std::vector> outputs(numOfClasses); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); - auto currentFFOut = listOfBPTensors->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); } @@ -842,7 +830,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -855,35 +843,35 @@ namespace helpers { // int numOfClasses = gradOut->sizeAt(0); // if input is a vector: (as if in doc sample) auto tempRes = gradOut->dup(); - unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, tempRes); + unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { Nd4jLong classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->e(classNum) - input->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) output->p(e, gradOut->e(classNum)); } } else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr 
listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int i = 0; i < indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors->at(i); - NDArray* currentOut = listOfOutTensors->at(i); - NDArray* currentGradOut = listOfGradOuts->at(classNum); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->e(e) - current->e(e)) < 1.e-5) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) currentOut->p(e, currentGradOut->e(e)); } } } - delete tempRes; + return ND4J_STATUS_OK; } @@ -895,13 +883,13 @@ namespace helpers { template static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { auto tempRes = gradOut->dup(); - unsortedSegmentMinFunctor(context, input, indices, numOfClasses, tempRes); + unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - if (nd4j::math::nd4j_abs(tempRes->t(classNum) - input->t(e)) < 1.e-6) + if (nd4j::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) output->t(e) = gradOut->t(classNum); } }; @@ -911,20 +899,20 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr 
listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { - if (nd4j::math::nd4j_abs(listOfBPTensors->at(classNum)->t(e) - current->t(e)) < 1.e-6) + if (nd4j::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) currentOut->t(e) = currentGradOut->t(e); } } @@ -932,7 +920,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return ND4J_STATUS_OK; } @@ -963,15 +951,15 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); for (int i = 0; i < 
indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors->at(i); - NDArray* currentOut = listOfOutTensors->at(i); - NDArray* currentGradOut = listOfGradOuts->at(classNum); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(*currentGradOut / double(classCount[classNum])); } } @@ -991,15 +979,15 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); currentOut->assign(currentGradOut); } @@ -1011,14 +999,14 @@ namespace helpers { } int unsortedSegmentProdFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto tempRes = gradOut->dup(); - unsortedSegmentProdFunctor(context, input, indices, numOfClasses, tempRes); + auto tempRes = gradOut->dup(); + unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); if (input->isVector()) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes->e(classNum) / input->e(e)); + 
output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); } }; @@ -1027,18 +1015,18 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfBPTensors(tempRes->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); - auto currentFFOut = listOfBPTensors->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); } @@ -1046,7 +1034,7 @@ namespace helpers { //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - delete tempRes; + return Status::OK(); } @@ -1076,16 +1064,16 @@ namespace helpers { else { auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - std::unique_ptr listOfGradOuts(gradOut->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfTensors(input->allTensorsAlongDimension(restDims)); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension(restDims)); + ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors 
=input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); //auto func = PRAGMA_THREADS_FOR { for (auto i = 0; i < indices->lengthOf(); i++) { auto classNum = indices->e(i); - auto current = listOfTensors->at(i); - auto currentOut = listOfOutTensors->at(i); - auto currentGradOut = listOfGradOuts->at(classNum); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); for (int e = 0; e < current->lengthOf(); e++) { currentOut->p(e, currentGradOut->e(e) / nd4j::math::nd4j_sqrt(classCount[classNum])); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp index 7a9b77b66..4b54c7362 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -29,7 +29,7 @@ namespace nd4j { return x >> shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -42,7 +42,7 @@ namespace nd4j { return x << shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -56,7 +56,7 @@ namespace nd4j { return x >> shift | x << step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -70,7 +70,7 @@ namespace nd4j { return x << shift | x >> step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp index 1fea14824..642dd37da 100644 --- 
a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp @@ -34,7 +34,7 @@ static FORCEINLINE NDArray activation(const NDArray& arr) { // return (const_cast&>(arr)).template transform>(); auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, &result); + (const_cast(arr)).applyTransform(transform::Tanh, result); return result; } @@ -125,7 +125,7 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -212,7 +212,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] @@ -306,7 +306,7 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr samediff::Threads::parallel_tad(func, 0, ncols); // gradB - gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K] + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] // gradW x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] diff --git a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp index b974a236b..db9b6afff 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp @@ -47,15 +47,13 @@ static void stack_(const std::vector& inArrs, NDArray* outArr, c std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim}); auto list = 
outArr->allTensorsAlongDimension(dimsToExclude); // list.size() == block.width() - int listSize = list->size(); + int listSize = list.size(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - list->at(i)->assign(inArrs[i]); + list.at(i)->assign(inArrs[i]); }; samediff::Threads::parallel_tad(func, 0, listSize); - - delete list; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 35615287b..9d755f6b6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -221,26 +221,26 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const T almostZero = DataTypeUtils::min(); T maxElem; if(len == 1) - maxElem = math::nd4j_abs(diagInterval->template e(0)); + maxElem = math::nd4j_abs(diagInterval.template e(0)); else - maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); + maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); T epsBig = (T)8. 
* DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - if(diagInterval->template e(0) < epsBig) - diagInterval->p(Nd4jLong(0), epsBig); + if(diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); for(int i=1; i < len; ++i) if(math::nd4j_abs(colVec0->template e(i)) < eps) colVec0->p(i, 0.f); for(int i=1; i < len; i++) - if(diagInterval->template e(i) < epsBig) { + if(diagInterval.template e(i) < epsBig) { deflation1(col1, shift, i, len); for(int i = 0; i < len; ++i) - diagInterval->p(i, _m.e(col1+shift+i,col1+shift+i)); + diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); } { @@ -259,7 +259,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh int p = 1; for(int i=1; i(diagInterval->template e(i)) < almostZero) + if(math::nd4j_abs(diagInterval.template e(i)) < almostZero) permut[p++] = i; int k = 1, m = ind+1; @@ -269,7 +269,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh permut[p] = m++; else if(m >= len) permut[p] = k++; - else if(diagInterval->template e(k) < diagInterval->template e(m)) + else if(diagInterval.template e(k) < diagInterval.template e(m)) permut[p] = m++; else permut[p] = k++; @@ -279,7 +279,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(totDefl) { for(int i=1; i(diagInterval->template e(ki)) < almostZero || diagInterval->template e(0) < diagInterval->template e(ki)) + if(math::nd4j_abs(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) permut[i-1] = permut[i]; else { permut[i-1] = 0; @@ -301,10 +301,10 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const int ki = permut[len - (totDefl ? 
i+1 : i)]; const int jac = tCol[ki]; - T _e0 = diagInterval->template e(jac); + T _e0 = diagInterval.template e(jac); //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval->p(jac, diagInterval->template e(i)); - diagInterval->p(i, _e0); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); if(i!=0 && jac!=0) { _e0 = colVec0->template e(jac); @@ -349,12 +349,12 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh { int i = len-1; - while(i > 0 && (math::nd4j_abs(diagInterval->template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) + while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) --i; for(; i > 1; --i) { - if( (diagInterval->template e(i) - diagInterval->template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval->template e(i) - diagInterval->template e(i-1)) >= epsBig) + if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { + if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); } @@ -362,7 +362,6 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh } delete colVec0; - delete diagInterval; } @@ -606,9 +605,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA const T almostZero = DataTypeUtils::min(); auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); - auto diag = *diagP; - delete diagP; + auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); diag.p(Nd4jLong(0), T(0)); singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); @@ -727,8 
+724,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); temp.assign(0.); auto diag = _m.diagonal('c'); - (*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - delete diag; + diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); return; } @@ -786,14 +782,10 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif temp.assign(_u({col1, col1+k+1, i, i+1}, true)); } - auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); - temp1.assign(q1 * c0); - auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); - temp2.assign(q1 * (-s0)); - auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true); - temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0); - auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true); - temp4 *= c0; + _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); + _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); + _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); + _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; } else { @@ -841,8 +833,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); blockM = 0.f; auto diag = blockM.diagonal('c'); - diag->assign(singVals); - delete diag; + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// @@ -958,16 +949,16 @@ static void svd_(const NDArray* x, const std::vector& outArrs, const b ResultSet* listU(nullptr), *listV(nullptr); if(calcUV) { - listU = u->allTensorsAlongDimension({rank-2, rank-1}); - listV = v->allTensorsAlongDimension({rank-2, rank-1}); + listU = new ResultSet(u->allTensorsAlongDimension({rank-2, rank-1})); + listV = new ResultSet(v->allTensorsAlongDimension({rank-2, rank-1})); } - 
for(int i = 0; i < listX->size(); ++i) { + for(int i = 0; i < listX.size(); ++i) { - // NDArray matrix(x->ordering(), {listX->at(i)->sizeAt(0), listX->at(i)->sizeAt(1)}, block.getContext()); - // matrix.assign(listX->at(i)); - helpers::SVD svdObj(*(listX->at(i)), switchNum, calcUV, calcUV, fullUV); - listS->at(i)->assign(svdObj._s); + // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), listX.at(i)->sizeAt(1)}, block.getContext()); + // matrix.assign(listX.at(i)); + helpers::SVD svdObj(*(listX.at(i)), switchNum, calcUV, calcUV, fullUV); + listS.at(i)->assign(svdObj._s); if(calcUV) { listU->at(i)->assign(svdObj._u); @@ -975,9 +966,6 @@ static void svd_(const NDArray* x, const std::vector& outArrs, const b } } - delete listX; - delete listS; - if(calcUV) { delete listU; delete listV; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index 0fc6eea0b..481575297 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -30,7 +30,7 @@ namespace nd4j { return BitwiseUtils::flip_bits(_x); }; - in.applyLambda(lambda, &out); + in.applyLambda(lambda, out); } void __toggle_bits(nd4j::LaunchContext * context, NDArray& in, NDArray& out) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index ea2fb348a..ea5e90cd8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -39,7 +39,7 @@ template static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { auto dOdI = NDArray(&gradO); // dO/dI - const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), 'b', &dOdI); + const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), dOdI, 'b'); int dLen = dOdI.lengthOf(); auto 
func = PRAGMA_THREADS_FOR { @@ -66,11 +66,9 @@ static void trace_(const NDArray& input, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - output.p(i, setOfSubArrs->at(i)->getTrace()); + output.p(i, setOfSubArrs.at(i)->getTrace()); }; - samediff::Threads::parallel_for(func, 0, setOfSubArrs->size()); - - delete setOfSubArrs; + samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); } void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output) { @@ -137,7 +135,7 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato if(i == r) continue; - subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); + subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); } } else { @@ -149,20 +147,18 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) for(int i = firstDim - 1; i > 0; --i) { int r = rng.relativeInt(i) % i; - subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); if(r == 0) isZeroShuffled = true; if(i == r) continue; - subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i])); + subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); math::nd4j_swap(indices[i], indices[r]); } if(!isZeroShuffled) - subArrsListOut->at(0)->assign(subArrsListIn->at(0)); - delete subArrsListOut; + subArrsListOut.at(0)->assign(subArrsListIn.at(0)); } rng.rewindH(firstDim-1); - delete subArrsListIn; } } @@ -715,12 +711,10 @@ void eye(nd4j::LaunchContext * context, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) - arrs->at(i)->setIdentity(); + arrs.at(i)->setIdentity(); }; - samediff::Threads::parallel_tad(func, 0, arrs->size()); - - delete arrs; + samediff::Threads::parallel_tad(func, 0, arrs.size()); } 
////////////////////////////////////////////////////////////////////////// @@ -752,25 +746,25 @@ void scatterUpdate(nd4j::LaunchContext * context, NDArray& input, NDArray& updat switch (opCode) { case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); break; case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); break; case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); break; case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); break; case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); break; case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); break; case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); break; default: continue; @@ -917,7 +911,7 @@ template static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { const int rank = input.rankOf(); - const auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); const T normActual = norm2.e(0); const T normClip = clipNorm.e(0); @@ -937,12 +931,10 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& for (auto i = 
start; i < stop; i += increment) { const T iNormActual = norm2.e(i); if (iNormActual > normClip) - *listOfInSubArrs->at(i) *= normClip / iNormActual; + *listOfInSubArrs.at(i) *= normClip / iNormActual; } }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs->size()); - - delete listOfInSubArrs; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } } else { @@ -961,8 +953,8 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { - auto inputSubArr = listOfInSubArrs->at(i); - auto outputSubArr = listOfOutSubArrs->at(i); + auto inputSubArr = listOfInSubArrs.at(i); + auto outputSubArr = listOfOutSubArrs.at(i); outputSubArr->assign(inputSubArr); const T iNormActual = norm2.e(i); @@ -971,10 +963,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& *outputSubArr *= clipNorm / iNormActual; } }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs->size()); - - delete listOfInSubArrs; - delete listOfOutSubArrs; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } } } @@ -1021,7 +1010,7 @@ void clipByNorm(nd4j::LaunchContext * context, NDArray& input, NDArray& output, else { auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, output); + input->applyLambda(lambda, *output); } } } @@ -1037,7 +1026,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); if(norm2.lengthOf() == 1) { @@ -1055,16 +1044,16 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); }; - (const_cast(input)).applyPairwiseLambda(const_cast(&gradO), lambda, &gradI); + 
(const_cast(input)).applyPairwiseLambda(const_cast(gradO), lambda, gradI); } else gradI.assign(gradO); } else { - const auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); - const auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); - const auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); + auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); + auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); + auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); auto cn = clipNorm.e(0); @@ -1072,11 +1061,11 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g for (auto i = start; i < stop; i += increment) { T N = norm2.e(i); - auto gradOSubArr = gradOSubArrs->at(i); - auto gradISubArr = gradISubArrs->at(i); + auto gradOSubArr = gradOSubArrs.at(i); + auto gradISubArr = gradISubArrs.at(i); if (N > cn) { - auto inputSubArr = inputSubArrs->at(i); + auto inputSubArr = inputSubArrs.at(i); const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar const T factor1 = static_cast(1.f) / N; const T factor3 = factor1 / (N * N); // 1 / (N*N*N) @@ -1085,16 +1074,12 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); }; - inputSubArr->applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); + inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); } else gradISubArr->assign(gradOSubArr); } }; - samediff::Threads::parallel_tad(func, 0, gradISubArrs->size()); - - delete gradISubArrs; - delete gradOSubArrs; - delete inputSubArrs; + samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); } } @@ -1120,25 +1105,24 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector(lambda, &output); + input.applyLambda(lambda, output); } } else { // along dimension - auto norm2 = input.reduceAlongDims(reduce::Norm2, 
dimensions, false); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); if (!isInplace) output.assign(input); auto tads = output.allTensorsAlongDimension(dimensions); // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads->size(); e++) { - T n2 = norm2.e(e) / tads->at(e)->lengthOf(); + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / tads.at(e)->lengthOf(); const T factor = cn / n2; if (n2 > cn) { auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads->at(e)->applyLambda(lambda, &output); + tads.at(e)->applyLambda(lambda, output); } } - delete tads; } } @@ -1164,7 +1148,7 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector(routine, &output); + input.applyLambda(routine, output); } void clipByValue(nd4j::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { diff --git a/libnd4j/include/ops/declarable/helpers/cross.h b/libnd4j/include/ops/declarable/helpers/cross.h index d087a4849..31b386e7e 100644 --- a/libnd4j/include/ops/declarable/helpers/cross.h +++ b/libnd4j/include/ops/declarable/helpers/cross.h @@ -65,23 +65,19 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND auto tadsB = b_.allTensorsAlongDimension({1}); auto tadsO = o_.allTensorsAlongDimension({1}); - int tads = tadsA->size(); + int tads = tadsA.size(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e += increment) { - auto a_ = tadsA->at(e); - auto b_ = tadsB->at(e); - auto o_ = tadsO->at(e); + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); helpers::cross(context, a_, b_, o_); } }; samediff::Threads::parallel_tad(func, 0, tads); - - delete tadsA; - delete tadsB; - delete tadsO; } void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); diff --git 
a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu index fabe6800a..4c746f244 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu @@ -244,7 +244,7 @@ namespace helpers { return res; }; - input->applyTriplewiseLambda(gradX, epsilon, gainsInternal, output); + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index 21b2eecd4..b2a13bfce 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -345,9 +345,9 @@ void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& outpu BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); - // auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); + // auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - // auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + // auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); // output /= sumAlongDim; // input.tickReadDevice(); } @@ -463,11 +463,11 @@ void logSoftmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& ou } else { - auto 
maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; - output.applyTransform(transform::Log); + output.applyTransform(transform::Log, output); input.tickReadDevice(); } @@ -580,9 +580,9 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); + (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily + auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output *= (1.f - output); // derivative input.tickReadDevice(); @@ -600,7 +600,7 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr auto routine = LAMBDA_T(_x, threshold) { return _x > (T)threshold ? 
_x: (T)0.f; }; - const_cast(input).applyLambda(routine, &output); + const_cast(input).applyLambda(routine, output); } void thresholdRelu(nd4j::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { @@ -611,7 +611,7 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr linkage void thresholdReluDerivative_(NDArray* input, double theta, NDArray* dLdO, NDArray* output) { auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - input->applyPairwiseLambda(dLdO, derivative, output); + input->applyPairwiseLambda(*dLdO, derivative, *output); } void thresholdReluDerivative(nd4j::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu index 450ac08cc..99fbd33a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu @@ -45,21 +45,21 @@ void bgemm(const std::vector& vA, const std::vector& vB, std for(int i = 0; i < bS; ++i) { if(vA[i]->ews() != 1) { - pA[i] = vA[i]->dup('f'); + pA[i] = new NDArray(vA[i]->dup('f')); toDelete.emplace_back(pA[i]); } else pA[i] = vA[i]; if(vB[i]->ews() != 1) { - pB[i] = vB[i]->dup('f'); + pB[i] = new NDArray(vB[i]->dup('f')); toDelete.emplace_back(pB[i]); } else pB[i] = vB[i]; if(vC[i]->ews() != 1) { - pC[i] = vC[i]->dup('f'); + pC[i] = new NDArray(vC[i]->dup('f')); toDelete.emplace_back(pC[i]); } else diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 6b86ce302..4f77b2e7c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, 
const N NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } @@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con NDArray* gradBR = gradB; if(gradB->rankOf() == 2) gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW if(gradBR != gradB) delete gradBR; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index aa47e3e88..cbdff509d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -49,8 +49,8 @@ namespace helpers { } return nd4j::math::nd4j_round(zeroPointFromMin); }(); - *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale); - *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale); + *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); + *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); } template @@ -75,7 +75,7 @@ namespace helpers { return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); }; - input->applyLambda(wiseMinMaxAndSoOn, output); + input->applyLambda(wiseMinMaxAndSoOn, *output); } template diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index 9d0e5e55b..a12b43973 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ 
b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -31,7 +31,7 @@ void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step return _x - (_y * weight); }; - input->applyPairwiseLambda(step, lambda, output); + input->applyPairwiseLambda(*step, lambda, *output); } void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu index cbbdf1439..82ab9d764 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu @@ -77,15 +77,15 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa // reset gate r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid); + r->applyTransform(transform::Sigmoid, *r); // update gate u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid); + u->applyTransform(transform::Sigmoid, *u); // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh); + c->applyTransform(transform::Tanh, *c); NDArray temp = 1.f - *c * *c; @@ -231,15 +231,15 @@ void gruCellBP(nd4j::LaunchContext* context, // reset gate NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid); + r.applyTransform(transform::Sigmoid, r); // update gate NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid); + u.applyTransform(transform::Sigmoid, u); // 
cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh); + c.applyTransform(transform::Tanh, c); // h = (1 - u) * c + u * hPrev @@ -352,10 +352,10 @@ void gruCellBP(nd4j::LaunchContext* context, dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU] + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index a5d686dc2..bf6a943fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -48,13 +48,12 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call */ auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - auto targetIdx = indexMax->e(0); + auto targetIdx = indexMax.e(0); dim3 launchDims(128, 512, 1024); BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES); manager.synchronize(); - delete indexMax; } else { Nd4jLong* hostYShapeInfo = nullptr; Nd4jLong* hostTShapeInfo = nullptr; @@ -71,10 +70,8 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* dimension = (int *) 
manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); // at this point, all IMax indexes are gathered, and we execute filler - BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr->specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr.specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); manager.synchronize(); - - delete indexMaxArr; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index 753c8ae64..a3d24111a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -33,7 +33,7 @@ namespace nd4j { return x > (T) 0.f ? y : T(0.f); }; - theFirst->applyPairwiseLambda(theSecond, functor, nullptr); + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { @@ -46,7 +46,7 @@ namespace nd4j { return x > (T)0.f ? y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -59,7 +59,7 @@ namespace nd4j { return x > (T)0.f && x < (T)6.f? y : T(0.f); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -75,7 +75,7 @@ namespace nd4j { return x < 0 ? 
alphaT * y : y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -91,7 +91,7 @@ namespace nd4j { return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { @@ -104,7 +104,7 @@ namespace nd4j { return y * simdOps::SELUDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 3a09f9a80..afd07cd48 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -34,7 +34,7 @@ namespace nd4j { return y * ((T)1.0f - (th * th)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -49,7 +49,7 @@ namespace nd4j { return y * simdOps::HardTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -62,7 +62,7 @@ namespace nd4j { return y * simdOps::RationalTanhDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, 
output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -75,7 +75,7 @@ namespace nd4j { return x > (T) 0.0f ? y * (nd4j::math::nd4j_tanhderivative(x)) : (T) 0.0f; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index fa97a3de2..fb4a94abb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -34,7 +34,7 @@ namespace helpers { return y * (3 * x * x); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,7 +50,7 @@ namespace helpers { return x > T(0.f)? y : -y; }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -66,7 +66,7 @@ namespace helpers { return nd4j::math::nd4j_max(x, (T)0.f) - x * y + nd4j::math::nd4j_log((T)1.f + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,7 +86,7 @@ namespace helpers { return static_cast(1.) - y - e / (static_cast(1.) 
+ e); }; - logits->applyPairwiseLambda(labels, functor, output); + logits->applyPairwiseLambda(*labels, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { @@ -104,7 +104,7 @@ namespace helpers { return y * ((T) 1.0f / (ss * ss)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -120,7 +120,7 @@ namespace helpers { return y * (p / (p + 1.)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -138,7 +138,7 @@ namespace helpers { return y * (s * ((T) 1.0f - s)); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -151,7 +151,7 @@ namespace helpers { return y * simdOps::HardSigmoidDerivative::op(x, nullptr); }; - input->applyPairwiseLambda(epsilon, functor, output); + input->applyPairwiseLambda(*epsilon, functor, *output); } void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { @@ -162,24 +162,24 @@ namespace helpers { template linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyTransform(transform::Exp, tempInput.get()); + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != 
nullptr) { axisVector.resize(axis->lengthOf()); for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } template linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { // reduce along axis with - std::unique_ptr tempInput(input->dup()); - input->applyPairwiseTransform(pairwise::Subtract, subtrah, tempInput.get()); - tempInput->applyTransform(transform::Exp, nullptr, nullptr); + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); std::vector axisVector; if (axis != nullptr) { @@ -187,8 +187,8 @@ namespace helpers { for (size_t i = 0; i < axisVector.size(); ++i) axisVector[i] = axis->e(i); } - tempInput->reduceAlongDimension(reduce::Sum, output, axisVector); - output->applyTransform(transform::Log, nullptr, nullptr); + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,16 +223,16 @@ namespace helpers { if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(targets), mainRoutineT1, output); + const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); } else { std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); std::unique_ptr targetTensor(new NDArray(*targets)); *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(targets), targetTensor.get(), 
mainRoutineT2, output); + const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu index 3fc7ef0b7..2204c9189 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu @@ -85,7 +85,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue); + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); if(peephole) zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc @@ -98,7 +98,7 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] // if clipping projection is provided then projected cell output state is clipped by this value if(clippingProjValue != 0.) 
- ht->applyScalar(scalar::LstmClip, clippingProjValue); + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); } else ht->assign(&htNoPeepHole); @@ -165,30 +165,30 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast if(forgetBias != 0.0) zf += forgetBias; - zz.applyTransform(transform::Tanh, z); //z = tanh(zz) - zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi) - zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf); + zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) + zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) + zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i auto temp = (*f) * (*cLast); *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) + c->applyTransform(transform::Tanh, *h); //h = tanh(c) // if clipping value is provided then cell state is clipped by this value prior to the cell output activation if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue); + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); if(peephole) { // add peephole connections to output gate zot + ct*Wc auto prod = *c * (*Wco); zo += prod; } - zo.applyTransform(transform::Sigmoid, o); // o = sigmoid(zo) + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, h, y, nullptr); //y = o * h + c->applyTransform(transform::Tanh, *h); //h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 4e5d9e85e..ce1dc2e95 100644 --- 
a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -603,7 +603,7 @@ namespace helpers { output->assign(input); // fill up output tensor with zeros output->tickWriteDevice(); - permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &iota, permutationVectors, true, nullptr); + permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr); permutationVectors->tickWriteDevice(); auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); @@ -839,7 +839,7 @@ namespace helpers { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { if (!inplace) output->assign(input); - std::unique_ptr tempOutput(output->dup()); + auto tempOutput =output->dup(); cusolverDnHandle_t handle = nullptr; auto n = input->sizeAt(-1); auto n2 = n * n; @@ -849,9 +849,9 @@ namespace helpers { throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); } F **dArrayBatch = nullptr; - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), - {tempOutput->rankOf() - 2, - tempOutput->rankOf() - 1}); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(), + {tempOutput.rankOf() - 2, + tempOutput.rankOf() - 1}); const Nd4jLong batchSize = packX.numberOfTads(); int *dInfoArray = nullptr; auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); @@ -865,7 +865,7 @@ namespace helpers { } auto stream = context->getCudaStream(); fillBatchKernel << < 1, batchSize, 128, *stream >> > - (dArrayBatch, reinterpret_cast(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize); + (dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), packX.specialOffsets(), batchSize); status = cusolverDnSetStream(handle, *stream); if (CUSOLVER_STATUS_SUCCESS != status) { @@ -895,7 +895,7 @@ 
namespace helpers { throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); } adjustResultsKernel << < batchSize, n2, 128, *stream >> > - (reinterpret_cast(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); + (reinterpret_cast(tempOutput.specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); err = cudaFree(dArrayBatch); if (err) { @@ -908,9 +908,9 @@ namespace helpers { } if (!inplace) - output->assign(tempOutput.get()); + output->assign(tempOutput); else - input->assign(tempOutput.get()); + input->assign(tempOutput); NDArray::registerSpecialUse({output}, {input}); return Status::OK(); @@ -978,7 +978,7 @@ namespace helpers { cholesky(context, input, &tempOutput, false); auto outputBuf = output->dataBuffer()->specialAsT(); //reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; - auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput->specialBuffer()); + auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput.specialBuffer()); output->nullify(); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(), {tempOutput.rankOf() - 2, diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu index ca24b3466..5a95eeb83 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu @@ -59,7 +59,7 @@ namespace helpers { auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - if (listOut->size() != listDiag->size()) { + if (listOut.size() != listDiag.size()) { nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); return ND4J_STATUS_VALIDATION; } diff --git 
a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index a2aec252e..b93563de2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -43,10 +43,10 @@ namespace nd4j { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -61,7 +61,7 @@ namespace nd4j { else gradY->assign(0.0f); - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace nd4j { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index 75c73f96b..90142091f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ 
b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -43,10 +43,10 @@ namespace nd4j { // PWT case case // X gradient - epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); // Y gradient - epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); } else if (y->isScalar()) { T s = y->e(0); @@ -61,7 +61,7 @@ namespace nd4j { else gradY->assign(0.0f); - epsNext->applyPairwiseLambda(x, lambdaS, gradX); + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); } else { // broadcast case @@ -71,8 +71,8 @@ namespace nd4j { auto targetShape = epsNext->getShapeAsVector(); - preX->tileToShape(targetShape); - preY->tileToShape(targetShape); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); @@ -81,22 +81,16 @@ namespace nd4j { auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); gradX->assign(sum); - delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); gradY->assign(sum); - delete sum; } else gradY->assign(preY); - - - delete preX; - delete preY; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu index ccfbbf943..79c9024f5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu @@ -94,7 +94,7 @@ namespace helpers { shape::checkDimensions(inputRank, axis); auto tempArray = input.dup(input.ordering()); - auto packX = 
ConstantTadHelper::getInstance()->tadForDimensions(tempArray->getShapeInfo(), axis); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(tempArray.getShapeInfo(), axis); auto tadLength = shape::length(packX.primaryShapeInfo()); @@ -114,11 +114,9 @@ namespace helpers { } position = tadLength - position - 1; - percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); + percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); nd4j::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); - - delete tempArray; } void percentile(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index 1c28b8f24..7014d6a50 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -82,8 +82,8 @@ namespace helpers { NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); - copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha)); - copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta)); + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice(); } diff --git 
a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index 90e15b21f..15335d57e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -181,27 +181,21 @@ namespace helpers { auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - for(int i = 0; i < inSubArrsSet->size(); ++i) { + for(int i = 0; i < inSubArrsSet.size(); ++i) { int numOfElemsToReverse = seqLengths->e(i); if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet->at(i)->assign(inSubArrsSet->at(i)); + outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); } else { - auto inInnerSet = inSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet->size(); ++j) - reverseArray(context, inInnerSet->at(j), outInnerSet->at(j), numOfElemsToReverse); - - delete inInnerSet; - delete outInnerSet; + auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + for(int j = 0; j < inInnerSet.size(); ++j) + reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), numOfElemsToReverse); } } - delete inSubArrsSet; - delete outSubArrsSet; } - } void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index d843feeff..bc53946d3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -235,9 +235,9 @@ namespace helpers { for (size_t i = 0; i < axes.size(); i++) { int axe = axes[i]; if (axe == input->rankOf() - 1) { // last 
dimension - std::unique_ptr listOfTensors(output->allTensorsAlongDimension({axe})); - std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); - int fullLen = listOfTensors->size(); + ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); int theShift = shifts[i]; // if (theShift > 0) { // theShift %= fullLen; @@ -246,7 +246,7 @@ namespace helpers { // theShift -= fullLen * (theShift / fullLen - 1); // } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); + rollFunctorLinear(output->getContext(), listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); } } else { std::vector dims(input->rankOf() - axe - 1); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index cab6e50e7..9585642dd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -212,7 +212,7 @@ namespace nd4j { NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); +// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), row, classes); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu index 49d388b2a..8ba3d40ce 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu +++ 
b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -29,7 +29,7 @@ namespace nd4j { return x >> shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -42,7 +42,7 @@ namespace nd4j { return x << shift; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -56,7 +56,7 @@ namespace nd4j { return x >> shift | x << step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { @@ -70,7 +70,7 @@ namespace nd4j { return x << shift | x >> step; }; - input.applyLambda(lambda, &output); + input.applyLambda(lambda, output); } void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index 5ce883a59..76530269c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -33,7 +33,7 @@ namespace helpers { static FORCEINLINE NDArray activation(const NDArray& arr) { // return (const_cast&>(arr)).template transform>(); auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, &result); + (const_cast(arr)).applyTransform(transform::Tanh, result); return result; } @@ -236,7 +236,7 @@ void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const ND // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -497,7 +497,7 @@ void sruBIBP(nd4j::LaunchContext* 
context, NDArray* x, const NDArray* w, const N // x = x * mask if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] @@ -522,7 +522,7 @@ void sruBIBP(nd4j::LaunchContext* context, NDArray* x, const NDArray* w, const N manager.synchronize(); // gradB - gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K] + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] // gradW x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index b39ebf81b..4d1b18eef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -148,24 +148,24 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, ND std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = new NDArray(U->dup('f')); toDelete.push_back(pU); } if(pVT->ews() != 1 || pVT->ordering() == 'c') { - pVT = VT->dup('f'); + pVT = new NDArray(VT->dup('f')); toDelete.push_back(pVT); } } @@ -276,8 +276,8 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N if(A->rankOf() != 2) throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); - auto m = A->sizeAt(0); - auto n = A->sizeAt(1); + int m = A->sizeAt(0); + int n = A->sizeAt(1); const int minDim = m < n ? 
m : n; if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S)) @@ -297,33 +297,53 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N } NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pV = V; + + const bool aForder = m == 1 || A->strideAt(0) == 1; + const bool aCorder = n == 1 || A->strideAt(1) == 1; + + const bool transA = !aForder && aCorder; + const bool dupA = !aForder && !aCorder; std::vector toDelete; - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + if(dupA) { + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } + NDArray* pS = S; + if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } + NDArray *pU(nullptr), *pV(nullptr); + + int lda = transA ? pA->strideAt(0) : pA->strideAt(1); + int ldu(transA ? n : m), ldv(transA ? m : n); + bool uForder(true), vForder(true); + if(calcUV) { - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = transA ? V : U; + pV = transA ? U : V; + + uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; + vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; + + if(!uForder) { + pU = new NDArray(pU->dup('f')); toDelete.push_back(pU); } - if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V->dup('f'); + if(!vForder) { + pV = new NDArray(pV->dup('f')); toDelete.push_back(pV); } + + ldu = pU->strideAt(1); + ldv = pV->strideAt(1); } // create cusolverDn handle @@ -353,19 +373,27 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; const int econ = !fullUV; - int lda(m), ldu(m), ldv(m); + if(transA) + math::nd4j_swap(m, n); - if(calcUV) { - ldu = pU->sizeAt(0); - ldv = pV->sizeAt(0); + // *** avoid bug in cuda API *** + void* nullPtr = nullptr; + NDArray* arrToAvoidBugInAPI = nullptr; + if(!calcUV && m != n) { + int maxDim = m > n ? 
m : n; + arrToAvoidBugInAPI = new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); + nullPtr = arrToAvoidBugInAPI->getSpecialBuffer(); } + // ****************** + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); + status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); + status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -380,14 +408,12 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N PointersManager manager(context, "svdJcb"); - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -399,13 +425,20 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, N NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - S->assign(pS); + if(S->ews() != 1) + S->assign(pS); if(calcUV) { - U->assign(pU); - V->assign(pV); + + if(!uForder) + U->assign(transA ? pV : pU); + if(!vForder) + V->assign(transA ? pU : pV); } + if(!calcUV && m != n) + delete arrToAvoidBugInAPI; + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; @@ -465,24 +498,24 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A->dup('f'); + pA = new NDArray(A->dup('f')); toDelete.push_back(pA); } if(S->ews() != 1) { - pS = S->dup('f'); + pS = new NDArray(S->dup('f')); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U->dup('f'); + pU = new NDArray(U->dup('f')); toDelete.push_back(pU); } if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V->dup('f'); + pV = new NDArray(V->dup('f')); toDelete.push_back(pV); } } @@ -618,15 +651,12 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vectorallTensorsAlongDimension({S->rankOf() - 1}); if(calcUV) { - tadsU = U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1}); - tadsV = V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1}); + tadsU = new ResultSet(U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); + tadsV = new ResultSet(V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); } - for (int i = 0; i < tadsX->size(); ++i) - svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? 
tadsV->at(i) : nullptr, fullUV, calcUV); - - delete tadsX; - delete tadsS; + for (int i = 0; i < tadsX.size(); ++i) + svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); if(calcUV) { delete tadsU; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index 8c67cbf1b..bc1171efe 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -30,7 +30,7 @@ namespace nd4j { return ~_x;//eUtils::flip_bits(_x); }; - in.applyLambda(lambda, &out); + in.applyLambda(lambda, out); } BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index 972013835..520a6115d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -251,7 +251,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con // we get top K values first if (k == 1) { - input->applyIndexReduce(indexreduce::IndexMax, indices, {input->rankOf() - 1}); + input->applyIndexReduce(indexreduce::IndexMax, *indices, {input->rankOf() - 1}); // copy values on specified indices topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 1a5a255ee..764b6abbf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ 
b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -649,7 +649,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr int r = rng.relativeInt(i) % i; if(i != r) - subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); + subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); } } else { @@ -661,21 +661,19 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr for(int i = firstDim - 1; i > 0; --i) { int r = rng.relativeInt(i) % i; - subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); if(r == 0) isZeroShuffled = true; if(i != r) { - subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i])); + subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); math::nd4j_swap(indices[i], indices[r]); } } if(!isZeroShuffled) - subArrsListOut->at(0)->assign(subArrsListIn->at(0)); - delete subArrsListOut; + subArrsListOut.at(0)->assign(subArrsListIn.at(0)); } rng.rewindH(firstDim-1); - delete subArrsListIn; } NDArray::registerSpecialUse({&output}, {&input}); @@ -747,7 +745,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr template static void clipByNorm_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, NDArray const& clipNormA, const bool isInplace) { const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); clipNormA.syncToHost(); //norm2.printBuffer("Norm2"); T const clipNorm = clipNormA.e(0); @@ -814,10 +812,10 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr globalNorm += l2norm * l2norm; } - globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = nd4j::math::nd4j_sqrt(globalNorm); + globalNorm.applyTransform(transform::Sqrt, globalNorm); // = nd4j::math::nd4j_sqrt(globalNorm); 
outputs[inputs.size()]->p(0, globalNorm); globalNorm.syncToHost(); - const T factor = clipNorm / globalNorm.e(0); + const T factor = static_cast(clipNorm) / globalNorm.e(0); for (size_t e = 0; e < inputs.size(); e++) { // all-reduce @@ -830,7 +828,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr else { auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, output); + input->applyLambda(lambda, *output); } } } @@ -848,7 +846,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr auto cn = clipNorm.e(0); if (dimensions.size() == 0) { // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / input.lengthOf(); + T n2 = input.reduceNumber(reduce::Norm2).e(0) / static_cast(input.lengthOf()); if (n2 <= cn) { if (!isInplace) output.assign(input); @@ -856,28 +854,26 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr else { const T factor = cn / n2; //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - //input.applyLambda(lambda, &output); + //input.applyLambda(lambda, output); output.assign(input * factor); } } else { // along dimension - auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions, false); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); if (!isInplace) output.assign(input); auto tads = output.allTensorsAlongDimension(dimensions); auto outTads = output.allTensorsAlongDimension(dimensions); // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads->size(); e++) { - T n2 = norm2.e(e) / tads->at(e)->lengthOf(); + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / static_cast(tads.at(e)->lengthOf()); const T factor = cn / n2; if (n2 > cn) { //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads->at(e)->applyScalar(scalar::Multiply, factor, outTads->at(e));//applyLambda(lambda, &output); + tads.at(e)->applyScalar(scalar::Multiply, factor, 
*outTads.at(e));//applyLambda(lambda, &output); } } - delete tads; - delete outTads; } } diff --git a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp index 4fb32e2f8..a75298af6 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp @@ -46,7 +46,7 @@ namespace helpers { // nd4j::NDArray comp1 = *comp; for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; @@ -59,7 +59,7 @@ namespace helpers { nd4j::NDArray arg1 = *arg; for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; @@ -74,7 +74,7 @@ namespace helpers { //for comparison for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); - if(result2 > 0) { + if(result2 > static_cast(0)) { if (output != nullptr) output->p(numResults, arg->e(i)); numResults++; diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 528642bb6..2b65d0c8e 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -130,7 +130,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2]); + c->applyScalar(scalar::LstmClip, params[2], *c); // peephole connections for output gate if(Wp != nullptr) @@ -206,22 +206,22 @@ void 
lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes - xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nIn] + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] if(h) - hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nOut] + hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] } else { dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [nIn] - h0Set = h0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] - c0Set = c0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] - ctSet = ct->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(h) - hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [nOut] + hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] if(ht) - htSet = ht->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut] + htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } // loops diff --git a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp index 179c7efab..3c65f740d 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp @@ -42,7 +42,7 @@ void rnnCell(nd4j::LaunchContext * 
context, const NDArray* xt, const NDArray* Wx // ht is current cell output [bS x nU], that is at current time step t ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] - ht->applyTransform(transform::Tanh); + ht->applyTransform(transform::Tanh, *ht); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/lstm.h b/libnd4j/include/ops/declarable/helpers/lstm.h index 91ca87738..9c0df2fa5 100644 --- a/libnd4j/include/ops/declarable/helpers/lstm.h +++ b/libnd4j/include/ops/declarable/helpers/lstm.h @@ -33,7 +33,7 @@ namespace helpers { } static FORCEINLINE void sigmoidInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Sigmoid); + (const_cast(arr)).applyTransform(transform::Sigmoid, const_cast(arr)); } ////////////////////////////////////////////////////////////////////////// @@ -42,7 +42,7 @@ namespace helpers { } static FORCEINLINE void tanhInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Tanh); + (const_cast(arr)).applyTransform(transform::Tanh, const_cast(arr)); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index d0bc16b66..a52c2c0e5 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -46,41 +46,41 @@ static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float switch (opId) { case 0: - (const_cast(x)).applyTransform(transform::Tanh, &z); + (const_cast(x)).applyTransform(transform::Tanh, z); break; case 1: - (const_cast(x)).applyScalar(scalar::RELU, 0, &z); + (const_cast(x)).applyScalar(scalar::RELU, 0, z); break; case 2: - (const_cast(x)).applyTransform(transform::Sigmoid, &z); + 
(const_cast(x)).applyTransform(transform::Sigmoid, z); break; case 3: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::Affine, &z, &args); + (const_cast(x)).applyTransform(transform::Affine, z, &args); break; } case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, &z); + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); break; case 5: helpers::thresholdRelu(x.getContext(), x, alpha, z); break; case 6: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::ScaledTanh, &z, &args); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); break; } case 7: - (const_cast(x)).applyTransform(transform::HardSigmoid, &z); + (const_cast(x)).applyTransform(transform::HardSigmoid, z); break; case 8: - (const_cast(x)).applyScalar(scalar::ELU, alpha, &z); + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); break; case 9: - (const_cast(x)).applyTransform(transform::SoftSign, &z); + (const_cast(x)).applyTransform(transform::SoftSign, z); break; case 10: - (const_cast(x)).applyTransform(transform::SoftPlus, &z); + (const_cast(x)).applyTransform(transform::SoftPlus, z); break; default: throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index 3e35e2c11..040cde77c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -38,7 +38,7 @@ namespace nd4j { } LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = scalar.dup(scalar.ordering()); + _scalar = new NDArray(scalar.dup(scalar.ordering())); } ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, nd4j::graph::Context &block) { diff --git 
a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp index 581bdae4c..5622f4316 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp @@ -38,7 +38,7 @@ namespace nd4j { } LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = scalar.dup(scalar.ordering()); + _scalar = new NDArray(scalar.dup(scalar.ordering())); } ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, nd4j::graph::Context &block) { diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index e66589b0a..c7111cc7a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -339,38 +339,38 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // x - mean NDArray xMinusMean(x); // empty array with same shape as x - const_cast(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean); + const_cast(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean); // stdInv NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dfdm / N - auto dfdm = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes); + auto dfdm = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes); dfdm *= stdInv; dfdm *= -Ninv; // dvdm / 2 NDArray dvdm(mean); // empty array with same shape as mean - xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, &dvdm, excludedAxes); + 
xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dvdm, excludedAxes); dvdm *= -Ninv; // (2/N)*dfdv NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, &dfdv, excludedAxes); + (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dfdv, excludedAxes); dfdv *= stdInv*stdInv*stdInv; dfdv *= -Ninv; // dvdm/2 + (x - m) - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dvdm); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dvdm, xMinusMean); // dfdv * (dvdm/2 + (x - m)) - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &dfdv); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dfdv, xMinusMean); // add dfdm / N - xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dfdm); + xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dfdm, xMinusMean); // * gamma auto gamma = (*weights)({0,1, 0,0}); - xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &gamma); + xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, xMinusMean); *dLdI += xMinusMean; } diff --git a/libnd4j/include/ops/impl/gemm.cpp b/libnd4j/include/ops/impl/gemm.cpp index 74b832b4a..a81c12818 100644 --- a/libnd4j/include/ops/impl/gemm.cpp +++ b/libnd4j/include/ops/impl/gemm.cpp @@ -100,7 +100,7 @@ namespace nd4j { } if (beta != 0.0) { - C[zIdx] = static_cast(dot + beta * C[zIdx]); + C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); } else { C[zIdx] = static_cast(dot); } @@ -134,8 +134,8 @@ namespace nd4j { int aIdx = linearIndexC(M, N, r, 0); auto aX = aT + aIdx; - auto dot = nd4j::math::nd4j_dot(aX, y, lda) * alpha; - z[r] = beta == 0.0f ? dot : dot + beta * z[r]; + auto dot = nd4j::math::nd4j_dot(aX, y, lda) * static_cast(alpha); + z[r] = beta == 0.0f ? 
dot : dot + static_cast(beta) * z[r]; } }; samediff::Threads::parallel_for(func, 0, M); diff --git a/libnd4j/include/ops/impl/specials.cpp b/libnd4j/include/ops/impl/specials.cpp index 11cca1b15..ad7f4060d 100644 --- a/libnd4j/include/ops/impl/specials.cpp +++ b/libnd4j/include/ops/impl/specials.cpp @@ -175,13 +175,13 @@ void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint PRAGMA_OMP_SIMD for (uint64_t i = 0; i < length; i++) { - z[i] /= n; + z[i] /= static_cast(n); } auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { for (Nd4jLong ar = 1; ar < n; ar++) { - z[i] += x[ar][i] / n; + z[i] += x[ar][i] / static_cast(n); } } }; @@ -201,7 +201,7 @@ void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) { for (Nd4jLong ar = 0; ar < n; ar++) { - z[i] += x[ar][i] / n; + z[i] += x[ar][i] / static_cast(n); } } }; @@ -365,11 +365,11 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) if (hasBit) { if (hasSign) - dz[(e - 4) * 16 + bitId] -= threshold; + dz[(e - 4) * 16 + bitId] -= static_cast(threshold); else - dz[(e - 4) * 16 + bitId] += threshold; + dz[(e - 4) * 16 + bitId] += static_cast(threshold); } else if (hasSign) { - dz[(e - 4) * 16 + bitId] -= threshold / 2; + dz[(e - 4) * 16 + bitId] -= static_cast(threshold / 2); } } } @@ -423,13 +423,13 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) if (val < (T) 0.0f) { byte |= 1 << (bitId + 16); - dx[e] += threshold; + dx[e] += static_cast(threshold); } else { - dx[e] -= threshold; + dx[e] -= static_cast(threshold); } } else if (abs >= (T) threshold / (T) 2.0f && val < (T) 0.0f) { byte |= 1 << (bitId + 16); - dx[e] += threshold / 2; + dx[e] += static_cast(threshold / 2); retVal++; } diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 22bb87103..b4960bc90 100644 --- 
a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -1298,7 +1298,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::interval(0,131072), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1322,7 +1322,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::interval(0,2*1024,2), NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); + strided = new NDArray(arr->subarray(indices)); delete arr; } @@ -1358,7 +1358,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1393,7 +1393,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1418,7 +1418,7 @@ namespace nd4j { strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr->subarray(indices); //All rows, first column + strided = new NDArray(arr->subarray(indices)); //All rows, first column delete arr; } @@ -1565,7 +1565,7 @@ namespace nd4j { int r = p.getIntParam("rowcol"); auto arr = NDArrayFactory::create_('c', {r, r+1}); IndicesList indices({NDIndex::all(), NDIndex::interval(0,r-1)}); - auto view = arr->subarray(indices); + auto view = new NDArray(arr->subarray(indices)); //nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), view->sizeAt(1)); x.push_back(view); if(p.getIntParam("inplace") == 1){ diff --git a/libnd4j/include/types/bfloat16.h 
b/libnd4j/include/types/bfloat16.h index 9b8081495..847c2ebda 100644 --- a/libnd4j/include/types/bfloat16.h +++ b/libnd4j/include/types/bfloat16.h @@ -47,489 +47,221 @@ //{ struct bfloat16 { - public: - int16_t _data; - /* constexpr */ local_def bfloat16() { _data = 0; } + private: + template + struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - template - local_def /*explicit*/ bfloat16(const T& rhs) { - assign(rhs); - } + public: + int16_t _data; -// local_def bfloat16(float rhs) { -// assign(rhs); -// } -// -// local_def bfloat16(double rhs) { -// assign(rhs); -// } + local_def bfloat16() { + _data = 0; + } - local_def operator float() const { - int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); + template ::value>::type> + local_def bfloat16(const T& rhs) { + *this = rhs; + } - return *reinterpret_cast(&temp); - } + local_def operator float() const { + int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); + return *reinterpret_cast(&temp); + } - local_def explicit operator double() const { return static_cast(static_cast(*this)); } - local_def explicit operator unsigned long long() const { return static_cast(static_cast(*this)); 
} - local_def explicit operator int16_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint16_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint32_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator uint8_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator int8_t() const { return static_cast(static_cast(*this)); } - local_def explicit operator int() const { return static_cast(static_cast(*this)); } - local_def explicit operator Nd4jLong() const { return static_cast(static_cast(*this)); } - local_def explicit operator bool() const { return this->_data == 0 ? false : true; } - local_def explicit operator float16() const { return static_cast(static_cast(*this)); } + local_def explicit operator bool() const { + return this->_data == 0 ? false : true; + } - template - local_def bfloat16& operator=(const T& rhs) { assign(rhs); return *this; } + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - local_def void assign(unsigned int rhs) { - // may be a better way ? - assign((float)rhs); - } + local_def bfloat16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f: 0.f; + return *this; + } - local_def void assign(int rhs) { - // may be a better way ? 
- assign((float)rhs); - } + local_def bfloat16& operator=(const float& rhs) { + #ifdef __CUDACC__ + if(::isnan(rhs)) { + _data = bfloat16::nan(); + return *this; + } + #endif + auto x = *reinterpret_cast(& const_cast(rhs)); + uint32_t lsb = (x >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + x += rounding_bias; + this->_data = static_cast(x >> 16); - local_def void assign(double rhs) { - assign((float)rhs); - } + return *this; + } - local_def void assign(long long rhs) { - assign((float)rhs); - } + local_def bfloat16& operator=(const bfloat16& rhs) { + _data = rhs._data; + return *this; + } - local_def void assign(long int rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def bfloat16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } - local_def void assign(long unsigned int rhs) { - assign((float)rhs); - } + local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } + local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } + local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } + local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } + local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } + local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - local_def void assign(unsigned short rhs) { - assign((float)rhs); - } + local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } + local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } + local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } + local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return 
bfloat16((float)a / (float)b); } - local_def void assign(float16 rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { return a + static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { return static_cast(a) + b; } - local_def void assign(long long unsigned int rhs) { - assign((float)rhs); - } + template ::value>::type> + local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { return a - static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { return static_cast(a) - b; } - local_def void assign(float rhs) { -#ifdef __CUDACC__ - if(::isnan(rhs)) { - _data = bfloat16::nan(); - return; - } -#endif - auto x = *reinterpret_cast(&rhs); - uint32_t lsb = (x >> 16) & 1; - uint32_t rounding_bias = 0x7fff + lsb; - x += rounding_bias; - this->_data = static_cast(x >> 16); - } + template ::value>::type> + local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { return a * static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { return static_cast(a) * b; } - local_def void assign(const bfloat16& rhs) { - _data = rhs._data; - } + template ::value>::type> + local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { return a / static_cast(b); } + template ::value>::type> + local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16& operator+=(bfloat16 rhs) { assign((float)(*this) + (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator==(const bfloat16& a, const T& b) { return a == static_cast(b); } + template ::value>::type> + local_def friend bool operator==(const T& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bfloat16& operator-=(bfloat16 rhs) { assign((float)*this - 
(float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator!=(const bfloat16& a, const T& b) { return a != static_cast(b); } + template ::value>::type> + local_def friend bool operator!=(const T& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bfloat16& operator*=(bfloat16 rhs) { assign((float)*this * (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator<(const bfloat16& a, const T& b) { return a < static_cast(b); } + template ::value>::type> + local_def friend bool operator<(const T& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bfloat16& operator/=(bfloat16 rhs) { assign((float)*this / (float)rhs); return *this; } + template ::value>::type> + local_def friend bool operator>(const bfloat16& a, const T& b) { return a > static_cast(b); } + template ::value>::type> + local_def friend bool operator>(const T& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bfloat16& operator+=(float rhs) { assign((float)*this + rhs); return *this; } + template ::value>::type> + local_def friend bool operator<=(const bfloat16& a, const T& b) { return a <= static_cast(b); } + template ::value>::type> + local_def friend bool operator<=(const T& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bfloat16& operator-=(float rhs) { assign((float)*this - rhs); return *this; } + template ::value>::type> + local_def friend bool operator>=(const bfloat16& a, const T& b) { return a >= static_cast(b); } + template ::value>::type> + local_def friend bool operator>=(const T& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bfloat16& operator*=(float rhs) { assign((float)*this * rhs); return *this; } + local_def bfloat16& operator+=(bfloat16 rhs) { *this = (float)(*this) + (float)rhs; return *this; } - local_def bfloat16& operator/=(float rhs) { assign((float)*this / rhs); return *this; } + local_def bfloat16& operator-=(bfloat16 rhs) { *this = 
(float)(*this) - (float)rhs; return *this; } - local_def bfloat16& operator++() { *this += 1.f; return *this; } + local_def bfloat16& operator*=(bfloat16 rhs) { *this = (float)(*this) * (float)rhs; return *this; } - local_def bfloat16& operator--() { *this -= 1.f; return *this; } + local_def bfloat16& operator/=(bfloat16 rhs) { *this = (float)(*this) / (float)rhs; return *this; } - local_def bfloat16 operator++(int i) { *this += i; return *this; } + template ::value>::type> + local_def bfloat16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } - local_def bfloat16 operator--(int i) { *this -= i; return *this; } + template ::value>::type> + local_def bfloat16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } - local_def std::ostream& operator<<(std::ostream& os) { - os << static_cast(*this); - return os; - } - local_def static bfloat16 min() { - bfloat16 res; - res._data = 0xFF7F; - return res; - } - local_def static bfloat16 max() { - bfloat16 res; - res._data = 0x7F7F; - return res; + template ::value>::type> + local_def bfloat16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } - } - local_def static bfloat16 eps() { - bfloat16 res; - res._data = 0x3C00; - return res; - } + template ::value>::type> + local_def bfloat16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } - local_def static bfloat16 inf() { - bfloat16 res; - res._data = 0x3C00; - return res; - } + local_def bfloat16& operator++() { *this = (float)*this + (float)1.f; return *this; } - local_def static bfloat16 nan() { - bfloat16 res; - res._data = 0x7FC0; - return res; - } - }; + local_def bfloat16& operator--() { *this = (float)*this - (float)1.f; return *this; } - local_def bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } + local_def bfloat16 operator++(int) { *this = (float)*this + (float)1.f; return *this; } -// template -// local_def bool operator==(const bfloat16& a, const T& b) { return (a == 
(bfloat16) b); } + local_def bfloat16 operator--(int) { *this = (float)*this - (float)1.f; return *this; } - local_def bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } -// - local_def bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } - - local_def bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } - - template - local_def bool operator>(const bfloat16& a, const T& b) { return (float)a > (float)b; } - - local_def bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } - template - local_def bool operator<=(const bfloat16& a, const T& b) { return (float)a <= (float)b; } - - local_def bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - - local_def bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } - local_def bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } - local_def bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } - local_def bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a / (float)b); } -// - - local_def bfloat16 operator+(const bfloat16& a, const double& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const float& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const float16& b) { return static_cast(static_cast(a) + static_cast(b)); } - local_def bfloat16 operator+(const float16& a, const bfloat16& b) { return static_cast(static_cast(a) + static_cast(b)); } - local_def bfloat16 operator+(const bfloat16& a, const int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const unsigned int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long long& b) { return a + static_cast(b); } 
- local_def bfloat16 operator+(const bfloat16& a, const unsigned long long& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const bool& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const int8_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const uint8_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const int16_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const uint16_t& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const bfloat16& a, const long unsigned int& b) { return a + static_cast(b); } - local_def bfloat16 operator+(const int8_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const uint8_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const int16_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const uint16_t& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const bool& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const unsigned int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const long long& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const unsigned long long& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const long int& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const float& a, const bfloat16& b) { return static_cast(a) + b; } - local_def bfloat16 operator+(const double& a, const bfloat16& b) { return 
static_cast(a) + b; } - local_def bfloat16 operator+(const long unsigned int& a, const bfloat16& b) { return static_cast(a) + b; } - - local_def bfloat16 operator-(const bfloat16& a, const double& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const float& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const float16& b) { return static_cast(static_cast(a) - static_cast(b)); } - local_def bfloat16 operator-(const float16& a, const bfloat16& b) { return static_cast(static_cast(a) - static_cast(b)); } - local_def bfloat16 operator-(const bfloat16& a, const int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const unsigned int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long long& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const unsigned long long& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const bool& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const int8_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const uint8_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const int16_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const uint16_t& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const bfloat16& a, const long unsigned int& b) { return a - static_cast(b); } - local_def bfloat16 operator-(const int8_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const uint8_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const int16_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def 
bfloat16 operator-(const uint16_t& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const bool& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const unsigned int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long long& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const unsigned long long& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long int& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const float& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const double& a, const bfloat16& b) { return static_cast(a) - b; } - local_def bfloat16 operator-(const long unsigned int& a, const bfloat16& b) { return static_cast(a) - b; } - - local_def bfloat16 operator/(const bfloat16& a, const double& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const float& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const float16& b) { return static_cast((float)a / (float)b); } - local_def bfloat16 operator/(const float16& a, const bfloat16& b) { return static_cast((float)a / (float)b); } - local_def bfloat16 operator/(const bfloat16& a, const int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const unsigned int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long long& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const unsigned long long& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const bool& b) { 
return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const int8_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const uint8_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const int16_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const uint16_t& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const bfloat16& a, const long unsigned int& b) { return a / static_cast(b); } - local_def bfloat16 operator/(const int8_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const uint8_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const int16_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const uint16_t& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const bool& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const unsigned int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long long& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const unsigned long long& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long int& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const float& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const double& a, const bfloat16& b) { return static_cast(a) / b; } - local_def bfloat16 operator/(const long unsigned int& a, const bfloat16& b) { return static_cast(a) / b; } - - local_def bfloat16 operator*(const bfloat16& a, const double& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const 
float& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const float16& b) { return static_cast((float)a * (float)b); } - local_def bfloat16 operator*(const float16& a, const bfloat16& b) { return static_cast((float)a * (float)b); } - local_def bfloat16 operator*(const bfloat16& a, const int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const unsigned int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long long& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const unsigned long long& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const bool& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const int8_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const uint8_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const int16_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const uint16_t& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const bfloat16& a, const long unsigned int& b) { return a * static_cast(b); } - local_def bfloat16 operator*(const int8_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const uint8_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const int16_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const uint16_t& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const bool& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 
operator*(const unsigned int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long long& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const unsigned long long& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long int& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const float& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const double& a, const bfloat16& b) { return static_cast(a) * b; } - local_def bfloat16 operator*(const long unsigned int& a, const bfloat16& b) { return static_cast(a) * b; } - - local_def bool operator==(const bfloat16& a, const float& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const float16& b) { return (float)a == (float)(b); } - local_def bool operator==(const bfloat16& a, const double& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const long long& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const unsigned long long& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const long int& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int8_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const uint8_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const int16_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const uint16_t& b) { return a == static_cast(b); } - local_def bool operator==(const bfloat16& a, const bool& b) { return a == static_cast(b); } - local_def bool 
operator==(const bfloat16& a, const long unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bool& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int8_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint8_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int16_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint16_t& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long long& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned long long& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long int& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const float& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const double& a, const bfloat16& b) { return static_cast(a) == b; } - local_def bool operator==(const long unsigned int& a, const bfloat16& b) { return static_cast(a) == b; } - - local_def bool operator!=(const bfloat16& a, const float& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const float16& b) { return (float)a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const double& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long long& b) { return a != static_cast(b); } - local_def bool operator!=(const 
bfloat16& a, const unsigned long long& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long int& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const uint8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const int16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const uint16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const bool& b) { return a != static_cast(b); } - local_def bool operator!=(const bfloat16& a, const long unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bool& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int8_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint8_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int16_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint16_t& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long long& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long int& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const float& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const double& a, const bfloat16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long unsigned int& a, const 
bfloat16& b) { return static_cast(a) != b; } - - local_def bool operator<(const bfloat16& a, const float& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const float16& b) { return (float)a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const double& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long long& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const unsigned long long& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long int& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int8_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const uint8_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const int16_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const uint16_t& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const bool& b) { return a < static_cast(b); } - local_def bool operator<(const bfloat16& a, const long unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bool& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int8_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint8_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int16_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint16_t& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const 
unsigned int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long long& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned long long& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long int& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const float& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const double& a, const bfloat16& b) { return static_cast(a) < b; } - local_def bool operator<(const long unsigned int& a, const bfloat16& b) { return static_cast(a) < b; } - - local_def bool operator>(const bfloat16& a, const float& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const float16& b) { return (float)a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const double& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long long& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const unsigned long long& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long int& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int8_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const uint8_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const int16_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const uint16_t& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const bool& b) { return a > static_cast(b); } - local_def bool operator>(const bfloat16& a, const long unsigned int& b) { return a > 
static_cast(b); } - local_def bool operator>(const bool& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int8_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint8_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int16_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint16_t& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long long& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned long long& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long int& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const float& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const double& a, const bfloat16& b) { return static_cast(a) > b; } - local_def bool operator>(const long unsigned int& a, const bfloat16& b) { return static_cast(a) > b; } - - local_def bool operator<=(const bfloat16& a, const float& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const float16& b) { return (float)a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const double& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const unsigned long long& b) { return a <= static_cast(b); } - local_def bool 
operator<=(const bfloat16& a, const long int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const uint8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const int16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const uint16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const bool& b) { return a <= static_cast(b); } - local_def bool operator<=(const bfloat16& a, const long unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bool& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int8_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint8_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int16_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint16_t& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long long& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long int& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const float& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const double& a, const bfloat16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long unsigned int& a, const bfloat16& b) { return static_cast(a) <= b; } - - local_def bool operator>=(const bfloat16& a, 
const float& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const float16& b) { return (float)a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const double& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const unsigned long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const uint8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const int16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const uint16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const bool& b) { return a >= static_cast(b); } - local_def bool operator>=(const bfloat16& a, const long unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bool& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int8_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint8_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int16_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint16_t& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned int& a, const bfloat16& b) { return 
static_cast(a) >= b; } - local_def bool operator>=(const long long& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned long long& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long int& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const float& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const double& a, const bfloat16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long unsigned int& a, const bfloat16& b) { return static_cast(a) >= b; } - - - local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { - os << static_cast(f); - return os; - } + local_def bfloat16 operator-() const { + return 0.f - (float)*this; + } - local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } - local_def bfloat16 operator - (const bfloat16& h) { - auto temp = h._data; - temp ^= 0x8000; - bfloat16 t; - t._data = temp; - return t; -} + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } + local_def static bfloat16 min() { + bfloat16 res; + res._data = 0xFF7F; + return res; + } + local_def static bfloat16 max() { + bfloat16 res; + res._data = 0x7F7F; + return res; + + } + local_def static bfloat16 eps() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 inf() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 nan() { + bfloat16 res; + res._data = 0x7FC0; + return res; + } +}; + + + +// local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { +// os << static_cast(f); +// return os; +// } + + +// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } + +// local_def bfloat16 operator - (const bfloat16& h) { +// auto temp = h._data; +// temp ^= 0x8000; +// bfloat16 t; +// t._data = temp; +// return t; 
+// } // WARNING: this implementation only for avoid cyclic references between float16 and bfloat16 types. -local_def void float16::assign(const bfloat16& rhs) { - assign((float)rhs); -} +// local_def void float16::assign(const bfloat16& rhs) { +// assign((float)rhs); +// } //} // namespace diff --git a/libnd4j/include/types/float16.h b/libnd4j/include/types/float16.h index 0cc75daed..4aa0d5d66 100644 --- a/libnd4j/include/types/float16.h +++ b/libnd4j/include/types/float16.h @@ -25,7 +25,6 @@ #include #endif - struct bfloat16; #ifdef __CUDACC__ @@ -224,505 +223,258 @@ local_def ihalf cpu_float2ihalf_rn(float f) } #endif - struct float16 - { - public: - ihalf data; - local_def float16() { *data.getXP() = 0; } +struct float16 { - template - local_def float16(const T& rhs) { - assign(rhs); - } + private: + template + struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };// || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - local_def operator float() const { -#ifdef __CUDA_ARCH__ - return __half2float(data); -#else - return cpu_ihalf2float(data); -#endif - } + public: + ihalf data; + local_def float16() { *data.getXP() = 0; } - local_def explicit operator double() const { - return static_cast(static_cast(*this)); - } + template ::value || 
std::is_same::value>::type> + local_def float16(const T& rhs) { + *this = rhs; + } - local_def explicit operator Nd4jLong() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator int() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator bool() const { - return static_cast(*this) > 0.0f; - } - - local_def explicit operator int16_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator uint16_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator uint8_t() const { - return static_cast(static_cast(*this)); - } - - local_def explicit operator int8_t() const { - return static_cast(static_cast(*this)); - } - - local_def operator half() const { return data; } - - template - local_def float16& operator=(const T& rhs) { assign(rhs); return *this; } - - local_def void assign(unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(int rhs) { - assign((float)rhs); - } - - local_def void assign(double rhs) { - assign((float)rhs); - } - - local_def void assign(long long rhs) { - assign((float)rhs); - } - - local_def void assign(long int rhs) { - assign((float)rhs); - } - - local_def void assign(const bool rhs) { - assign(rhs ? 
1.0f : 0.0f); - } - - local_def void assign(long unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(unsigned short rhs) { - *data.getXP() = rhs; - } - - local_def void assign(long long unsigned int rhs) { - assign((float)rhs); - } - - local_def void assign(float rhs) { -#ifdef __CUDA_ARCH__ - auto t = __float2half_rn(rhs); - auto b = *(data.getXP()); - -#ifdef CUDA_8 - *(data.getXP()) = t; -#else - data.assign(t); -#endif - -#else - data = cpu_float2ihalf_rn(rhs); -#endif - } - - local_def void assign(const ihalf& rhs) { - *data.getXP() = ((ihalf) rhs).getX(); - } - - local_def void assign(const bfloat16& rhs); - -#ifdef __CUDACC__ - local_def void assign(const half& rhs) { - data.assign(rhs); - } -#endif - - local_def void assign(const float16& rhs) { - data = rhs.data; - } - - local_def float16& operator+=(float16 rhs) { assign((float)*this + rhs); return *this; } - - local_def float16& operator-=(float16 rhs) { assign((float)*this - rhs); return *this; } - - local_def float16& operator*=(float16 rhs) { assign((float)*this * rhs); return *this; } - - local_def float16& operator/=(float16 rhs) { assign((float)*this / rhs); return *this; } - - local_def float16& operator+=(float rhs) { assign((float)*this + rhs); return *this; } - - local_def float16& operator-=(float rhs) { assign((float)*this - rhs); return *this; } - - local_def float16& operator*=(float rhs) { assign((float)*this * rhs); return *this; } - - local_def float16& operator/=(float rhs) { assign((float)*this / rhs); return *this; } - - local_def float16& operator++() { assign(*this + 1.f); return *this; } - - local_def float16& operator--() { assign(*this - 1.f); return *this; } - - local_def float16 operator++(int i) { assign(*this + (float)i); return *this; } - - local_def float16 operator--(int i) { assign(*this - (float)i); return *this; } - - local_def std::ostream& operator<<(std::ostream& os) { - os << static_cast(*this); - return os; - } - }; - - -#ifdef NATIVE_HALFS - 
local_def bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } -#else - local_def bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } -#else - local_def bool operator!=(const float16& a, const float16& b) { return !(a == b); } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } -#else - local_def bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } -#endif - -#ifdef NATIVE_HALFS - local_def bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } -#else - local_def bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } -#endif - - template - local_def bool operator>(const float16& a, const T& b) { return (float)a > (float)b; } - -#ifdef NATIVE_HALFS - local_def bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } -#else - local_def bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } -#endif - template - local_def bool operator<=(const float16& a, const T& b) { return (float)a <= (float)b; } - -#ifdef NATIVE_HALFS - local_def bool operator>=(const float16& a, const float16& b) { return __hge(a.data, b.data); } -#else - local_def bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } -#endif - -#ifdef NATIVE_HALFS - local_def float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } - - local_def float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } - - local_def float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } - - local_def float16 operator/(const float16& a, const float16& b) 
{ - #ifdef CUDA_8 - return hdiv(a.data, b.data); - #else - return __hdiv(a.data, b.data); + local_def float16(const half& rhs) { + #ifdef __CUDACC__ + data.assign(rhs); #endif - } -#else - local_def float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } - local_def float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } - local_def float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } - local_def float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } -#endif + } - local_def float16 operator+(const float16& a, const double& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const float& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const unsigned int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long long& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const unsigned long long& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long int& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const bool& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int8_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const uint8_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const int16_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const uint16_t& b) { return a + static_cast(b); } - local_def float16 operator+(const float16& a, const long unsigned int& b) { return a + static_cast(b); } - local_def float16 operator+(const int8_t& a, const float16& b) { return static_cast(a) + 
b; } - local_def float16 operator+(const uint8_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const int16_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const uint16_t& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const bool& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const unsigned int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long long& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const unsigned long long& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long int& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const float& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const double& a, const float16& b) { return static_cast(a) + b; } - local_def float16 operator+(const long unsigned int& a, const float16& b) { return static_cast(a) + b; } + local_def operator float() const { + #ifdef __CUDA_ARCH__ + return __half2float(data); + #else + return cpu_ihalf2float(data); + #endif + } - local_def float16 operator-(const float16& a, const double& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const float& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const unsigned int& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long long& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const unsigned long long& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long int& b) { return 
a - static_cast(b); } - local_def float16 operator-(const float16& a, const bool& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int8_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const uint8_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const int16_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const uint16_t& b) { return a - static_cast(b); } - local_def float16 operator-(const float16& a, const long unsigned int& b) { return a - static_cast(b); } - local_def float16 operator-(const int8_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const uint8_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const uint16_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const int16_t& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const bool& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const unsigned int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long long& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const unsigned long long& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long int& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const float& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const double& a, const float16& b) { return static_cast(a) - b; } - local_def float16 operator-(const long unsigned int& a, const float16& b) { return static_cast(a) - b; } + local_def explicit operator bool() const { + return static_cast(*this) != 0.0f; + } - local_def float16 
operator/(const float16& a, const double& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const float& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const unsigned int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long long& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const unsigned long long& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long int& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const bool& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int8_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const uint8_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const int16_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const uint16_t& b) { return a / static_cast(b); } - local_def float16 operator/(const float16& a, const long unsigned int& b) { return a / static_cast(b); } - local_def float16 operator/(const int8_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const uint8_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const uint16_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const int16_t& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const bool& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const unsigned int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long long& a, const 
float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const unsigned long long& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long int& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const float& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const double& a, const float16& b) { return static_cast(a) / b; } - local_def float16 operator/(const long unsigned int& a, const float16& b) { return static_cast(a) / b; } - - local_def float16 operator*(const float16& a, const double& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const float& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const unsigned int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long long& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const unsigned long long& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long int& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const bool& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int8_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const uint8_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const int16_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const uint16_t& b) { return a * static_cast(b); } - local_def float16 operator*(const float16& a, const long unsigned int& b) { return a * static_cast(b); } - local_def float16 operator*(const int8_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const uint8_t& a, const float16& b) { return 
static_cast(a) * b; } - local_def float16 operator*(const uint16_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const int16_t& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const bool& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const unsigned int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long long& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const unsigned long long& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long int& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const float& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const double& a, const float16& b) { return static_cast(a) * b; } - local_def float16 operator*(const long unsigned int& a, const float16& b) { return static_cast(a) * b; } + local_def explicit operator half() const { + return data; + } - local_def bool operator==(const float16& a, const float& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const double& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long long& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const unsigned long long& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long int& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int8_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, 
const uint8_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const int16_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const uint16_t& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const bool& b) { return a == static_cast(b); } - local_def bool operator==(const float16& a, const long unsigned int& b) { return a == static_cast(b); } - local_def bool operator==(const bool& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int8_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint8_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const uint16_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int16_t& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long long& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const unsigned long long& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long int& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const float& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const double& a, const float16& b) { return static_cast(a) == b; } - local_def bool operator==(const long unsigned int& a, const float16& b) { return static_cast(a) == b; } + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - local_def bool operator!=(const float16& a, const float& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const double& b) { return a != 
static_cast(b); } - local_def bool operator!=(const float16& a, const int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long long& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const unsigned long long& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long int& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const int8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const uint8_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const int16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const uint16_t& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const bool& b) { return a != static_cast(b); } - local_def bool operator!=(const float16& a, const long unsigned int& b) { return a != static_cast(b); } - local_def bool operator!=(const bool& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int8_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint8_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int16_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const uint16_t& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long long& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const unsigned long long& a, const float16& b) { return static_cast(a) != b; } - local_def bool 
operator!=(const long int& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const float& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const double& a, const float16& b) { return static_cast(a) != b; } - local_def bool operator!=(const long unsigned int& a, const float16& b) { return static_cast(a) != b; } + local_def float16& operator=(const float& rhs) { + #ifdef __CUDA_ARCH__ + auto t = __float2half_rn(rhs); + auto b = *(data.getXP()); - local_def bool operator<(const float16& a, const float& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const double& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long long& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const unsigned long long& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long int& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int8_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const uint8_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const int16_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const uint16_t& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const bool& b) { return a < static_cast(b); } - local_def bool operator<(const float16& a, const long unsigned int& b) { return a < static_cast(b); } - local_def bool operator<(const bool& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const int8_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint8_t& a, const float16& b) { 
return static_cast(a) < b; } - local_def bool operator<(const int16_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const uint16_t& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long long& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const unsigned long long& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long int& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const float& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const double& a, const float16& b) { return static_cast(a) < b; } - local_def bool operator<(const long unsigned int& a, const float16& b) { return static_cast(a) < b; } + #ifdef CUDA_8 + *(data.getXP()) = t; + #else + data.assign(t); + #endif - local_def bool operator>(const float16& a, const float& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const double& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long long& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const unsigned long long& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long int& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int8_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const uint8_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const int16_t& b) { return a > 
static_cast(b); } - local_def bool operator>(const float16& a, const uint16_t& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const bool& b) { return a > static_cast(b); } - local_def bool operator>(const float16& a, const long unsigned int& b) { return a > static_cast(b); } - local_def bool operator>(const bool& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int8_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint8_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int16_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const uint16_t& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long long& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const unsigned long long& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long int& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const float& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const double& a, const float16& b) { return static_cast(a) > b; } - local_def bool operator>(const long unsigned int& a, const float16& b) { return static_cast(a) > b; } - - local_def bool operator<=(const float16& a, const float& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const double& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long long& b) 
{ return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const unsigned long long& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long int& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const uint8_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const int16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const uint16_t& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const bool& b) { return a <= static_cast(b); } - local_def bool operator<=(const float16& a, const long unsigned int& b) { return a <= static_cast(b); } - local_def bool operator<=(const bool& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int8_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint8_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int16_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const uint16_t& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long long& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const unsigned long long& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const long int& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const float& a, const float16& b) { return static_cast(a) <= b; } - local_def bool operator<=(const double& a, const float16& b) { return static_cast(a) <= b; } - local_def bool 
operator<=(const long unsigned int& a, const float16& b) { return static_cast(a) <= b; } + #else + data = cpu_float2ihalf_rn(rhs); + #endif - local_def bool operator>=(const float16& a, const float& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const double& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const unsigned long long& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long int& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const uint8_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const int16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const uint16_t& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const bool& b) { return a >= static_cast(b); } - local_def bool operator>=(const float16& a, const long unsigned int& b) { return a >= static_cast(b); } - local_def bool operator>=(const bool& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int8_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint8_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int16_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const uint16_t& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool 
operator>=(const unsigned int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long long& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const unsigned long long& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long int& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const float& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const double& a, const float16& b) { return static_cast(a) >= b; } - local_def bool operator>=(const long unsigned int& a, const float16& b) { return static_cast(a) >= b; } - + return *this; + } - local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { - os << static_cast(f); - return os; - } + local_def float16& operator=(const unsigned short rhs) { + *data.getXP() = rhs; + return *this; + } - local_def float16 operator+(const float16& h) { return h; } + local_def float16& operator=(const bool rhs) { + *this = (float)rhs ? 
1.f: 0.f; + return *this; + } - local_def float16 operator - (const float16& h) { - const ihalf * tmp = &h.data; - return float16(hneg(tmp->getX())); -} + local_def float16& operator=(const ihalf& rhs) { + *data.getXP() = ((ihalf) rhs).getX(); + return *this; + } + + #ifdef __CUDACC__ + local_def float16& operator=(const half& rhs) { + data.assign(rhs); + return *this; + } + #endif + + local_def float16& operator=(const float16& rhs) { + data = rhs.data; + return *this; + } + + template ::value || std::is_same::value>::type> + local_def float16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + + #ifdef NATIVE_HALFS + local_def friend bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } + #else + local_def friend bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } + #else + local_def friend bool operator!=(const float16& a, const float16& b) { return !(a == b); } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } + #else + local_def friend bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } + #else + local_def friend bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } + #else + local_def friend bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend bool operator>=(const float16& a, const float16& b) { return 
__hge(a.data, b.data); } + #else + local_def friend bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } + #endif + + #ifdef NATIVE_HALFS + local_def friend float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } + + local_def friend float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } + + local_def friend float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } + + local_def friend float16 operator/(const float16& a, const float16& b) { + #ifdef CUDA_8 + return hdiv(a.data, b.data); + #else + return __hdiv(a.data, b.data); + #endif + } + #else + local_def friend float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } + local_def friend float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } + local_def friend float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } + local_def friend float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } + #endif + + template ::value>::type> + local_def friend float16 operator+(const float16& a, const T& b) { return a + static_cast(b); } + template ::value>::type> + local_def friend float16 operator+(const T& a, const float16& b) { return static_cast(a) + b; } + + template ::value>::type> + local_def friend float16 operator-(const float16& a, const T& b) { return a - static_cast(b); } + template ::value>::type> + local_def friend float16 operator-(const T& a, const float16& b) { return static_cast(a) - b; } + + template ::value>::type> + local_def friend float16 operator*(const float16& a, const T& b) { return a * static_cast(b); } + template ::value>::type> + local_def friend float16 operator*(const T& a, const float16& b) { return static_cast(a) * b; } + + template ::value>::type> + local_def friend float16 operator/(const float16& a, const T& 
b) { return a / static_cast(b); } + template ::value>::type> + local_def friend float16 operator/(const T& a, const float16& b) { return static_cast(a) / b; } + + template ::value>::type> + local_def friend bool operator==(const float16& a, const T& b) { return a == static_cast(b); } + template ::value>::type> + local_def friend bool operator==(const T& a, const float16& b) { return static_cast(a) == b; } + + template ::value>::type> + local_def friend bool operator!=(const float16& a, const T& b) { return a != static_cast(b); } + template ::value>::type> + local_def friend bool operator!=(const T& a, const float16& b) { return static_cast(a) != b; } + + template ::value>::type> + local_def friend bool operator<(const float16& a, const T& b) { return a < static_cast(b); } + template ::value>::type> + local_def friend bool operator<(const T& a, const float16& b) { return static_cast(a) < b; } + + template ::value>::type> + local_def friend bool operator>(const float16& a, const T& b) { return a > static_cast(b); } + template ::value>::type> + local_def friend bool operator>(const T& a, const float16& b) { return static_cast(a) > b; } + + template ::value>::type> + local_def friend bool operator<=(const float16& a, const T& b) { return a <= static_cast(b); } + template ::value>::type> + local_def friend bool operator<=(const T& a, const float16& b) { return static_cast(a) <= b; } + + template ::value>::type> + local_def friend bool operator>=(const float16& a, const T& b) { return a >= static_cast(b); } + template ::value>::type> + local_def friend bool operator>=(const T& a, const float16& b) { return static_cast(a) >= b; } + + local_def float16& operator+=(float16 rhs) { *this = (float)*this + (float)rhs; return *this; } + + local_def float16& operator-=(float16 rhs) { *this = (float)*this - (float)rhs; return *this; } + + local_def float16& operator*=(float16 rhs) { *this = (float)*this * (float)rhs; return *this; } + + local_def float16& operator/=(float16 rhs) { 
*this = (float)*this / (float)rhs; return *this; } + + template ::value>::type> + local_def float16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } + + template ::value>::type> + local_def float16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } + + template ::value>::type> + local_def float16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } + + template ::value>::type> + local_def float16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } + + local_def float16& operator++() { *this = *this + (float16)1.f; return *this; } + + local_def float16& operator--() { *this = *this - (float16)1.f; return *this; } + + local_def float16 operator++(int) { *this = *this + (float16)1.f; return *this; } + + local_def float16 operator--(int) { *this = *this - (float16)1.f; return *this; } + + local_def float16 operator-() const { + return 0.f - (float)*this; + } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } +}; + + + + // local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { + // os << static_cast(f); + // return os; + // } + + // local_def float16 operator+(const float16& h) { return h; } + + // local_def float16 operator - (const float16& h) { + // const ihalf * tmp = &h.data; + // return float16(hneg(tmp->getX())); + // } #ifdef __CUDACC__ local_def int isnan(const float16& h) { return ishnan_(((ihalf)h.data).getX()); } @@ -730,6 +482,6 @@ local_def ihalf cpu_float2ihalf_rn(float f) local_def int isinf(const float16& h) { return ishinf_(((ihalf)h.data).getX()); } #endif - std::ostream& operator << (std::ostream& s, const float16&); + // std::ostream& operator << (std::ostream& s, const float16&); #endif diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 238c2f15d..655683687 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -43,7 +43,7 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) { //exp.printIndexedBuffer("E B"); - exp.applyBroadcast(broadcast::Add, {1}, &y); + exp.applyBroadcast(broadcast::Add, {1}, y, exp); nd4j::ops::add op; auto result = op.execute({&x, &y}, {}, {}, {}); @@ -70,7 +70,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) { y.linspace(1); exp.linspace(1); - exp.applyBroadcast(broadcast::Multiply, {1}, &y); + exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); nd4j::ops::multiply op; auto result = op.execute({&x, &y}, {}, {}, {}); @@ -94,7 +94,7 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { y.linspace(1); exp.linspace(1); - exp.applyBroadcast(broadcast::SquaredSubtract, {1}, &y); + exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); nd4j::ops::squaredsubtract op; @@ -856,7 +856,7 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { z.printIndexedBuffer(); */ - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); //z.printIndexedBuffer(); @@ -874,7 +874,7 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { z.assign(119.f); e.assign(2.f); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp index 9134ef0a4..e025aaead 100644 --- a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -120,7 +120,7 @@ TEST_F(ConstantShapeHelperTests, basic_test_3) { TEST_F(ConstantShapeHelperTests, basic_test_4) { auto array = NDArrayFactory::create_('c', {128, 256}); - auto dup = array->dup('f'); + auto dup = new NDArray(array->dup('f')); ASSERT_TRUE(dup->shapeInfo() != nullptr); @@ -165,12 +165,11 @@ 
TEST_F(ConstantShapeHelperTests, basic_test_7) { IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); auto strided = array->subarray(indices); - strided->assign(1.0f); + strided.assign(1.0f); //strided->printIndexedBuffer("column"); delete array; - delete strided; } TEST_F(ConstantHelperTests, basic_test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 42e141f46..13316fe8d 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -130,7 +130,7 @@ TEST_F(ContextTests, Basic_Test_5) { auto _20 = NDArrayFactory::create_('c', {2, 2}); _20->linspace(1); - auto exp = _20->dup(); + auto exp = new NDArray(_20->dup()); ctx.pushNDArrayToVariableSpace(1, 1, _20); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index eccb73c6c..667eeddce 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -422,8 +422,8 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 
317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, - 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, - 999.0947414f, 1005.0337414f, 1010.972741f, 
1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 
1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, + 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 
820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 
1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 
1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 
2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, @@ -443,9 +443,9 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { weightsD.permutei({2,3,1,0}); weightsP.permutei({2,3,1,0}); - input.applyScalar(scalar::Divide, 100.0); - weightsD.applyScalar(scalar::Divide, 100.0); - weightsP.applyScalar(scalar::Divide, 100.0); + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); nd4j::ops::sconv2d op; @@ -653,7 +653,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { weightsD.permutei({2,3,1,0}); weightsP.permutei({2,3,1,0}); - weightsP.applyScalar(scalar::Divide, 10000.0); + weightsP.applyScalar(scalar::Divide, 10000.0, weightsP); nd4j::ops::sconv2d op; auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); @@ -793,7 +793,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { nd4j::ops::conv1d_bp op_bp; - 
auto epsilonNxt = z->dup(); + auto epsilonNxt = new NDArray(z->dup()); epsilonNxt->linspace(1); auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index de3cdcdba..d8e6379b1 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -110,13 +110,13 @@ TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 
60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -240,10 +240,10 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { weightsD.permutei({2,3,1,0}); weightsP.permutei({2,3,1,0}); - input.applyScalar(scalar::Divide, 100.0); - weightsD.applyScalar(scalar::Divide, 100.0); - weightsP.applyScalar(scalar::Divide, 100.0); - epsilonNext.applyScalar(scalar::Divide, 100.0); + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); nd4j::ops::sconv2d_bp op; auto resultBP = op.execute({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); @@ -1132,11 +1132,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = 
NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, - 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, - 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, - 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, - 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 
96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); input.linspace(1.); @@ -1160,8 +1160,8 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, - 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); input.linspace(1.); @@ -1185,23 +1185,23 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { int dataFormat = 0; 
// 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, - 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, - 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, - 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, - 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, - 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, - 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, - 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, - 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, - 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 
38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, - 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, - 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, - 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, - 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, - 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, - 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, - 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, + 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 
7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, + 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, + 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, + 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, + 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, + 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, + 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, + 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, + 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 
51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, + 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, + 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, + 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); input.linspace(1.); @@ -1225,7 +1225,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); input.linspace(1.); @@ -1249,11 +1249,11 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = 
NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 
107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); input.linspace(1.); @@ -1277,7 +1277,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); input.linspace(1.); @@ -1301,13 +1301,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { int dataFormat = 0; // -NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 
12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, - 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, - 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, - 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, - 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, - 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, - 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 
180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, + 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, + 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, + 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, + 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 
167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, + 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); input.linspace(1.); @@ -1332,14 +1332,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 
1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 
0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); input.linspace(1.); gradO = 2.; @@ -1366,14 +1366,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 
1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 
2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); input.linspace(1.); gradO = 2.; @@ -1402,13 +1402,13 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 
2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 
3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); input.linspace(1.); gradO = 2.; @@ -1434,13 +1434,13 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 
1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, - 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, + 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 
1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); input.linspace(1.); gradO = 2.; @@ -1466,11 +1466,11 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1496,14 +1496,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, - 
0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 
1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 
0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1529,13 +1529,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, - 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, - 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
+ 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1562,12 +1562,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, - 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1651,8 +1651,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 
1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1678,8 +1678,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, - 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, + 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1705,8 +1705,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 
1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, - 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1732,8 +1732,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, - 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, + 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1841,11 +1841,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = 
NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, - 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, - 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, - 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, - 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f }); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1872,11 +1872,11 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, - 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, - 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, - 3.433333f, 
3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, - 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1903,9 +1903,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, - 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, - 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, + 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 
3.925f, + 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1932,9 +1932,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, - 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, - 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, + 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, + 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1994,11 +1994,11 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto 
gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, - 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, - 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, - 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, - 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 
6.804127e-01f, 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -2029,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, - 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, - 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, - 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); input.linspace(1.); gradO.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 45b35eb4e..b646a493d 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -126,8 
+126,8 @@ TEST_F(DataTypesValidationTests, cast_1) { float16 x = static_cast(1.f); float y = static_cast(x); - ASSERT_TRUE(1.f == x); - ASSERT_TRUE(y == x); + ASSERT_TRUE(static_cast(1.f) == x); + ASSERT_TRUE(y == static_cast(x)); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 591746804..1e43081c1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { x.assign(3.f); y.assign(1.f); exp.assign(-2.f); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); ASSERT_TRUE(exp.equalsTo(&z)); @@ -811,7 +811,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { x.assign(1); y.assign(3); exp.assign(2); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); ASSERT_TRUE(z.equalsTo(&exp)); nd4j::ops::reversesubtract subOp; @@ -833,10 +833,10 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { x.assign(2.); y.assign(9.f); exp.assign(1.f); - y.applyTrueBroadcast(BROADCAST(Mod), &x, &z, true); + y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); ASSERT_TRUE(exp.equalsTo(&z)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); ASSERT_TRUE(exp.equalsTo(&z)); nd4j::ops::reversemod subOp; @@ -861,9 +861,9 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { x.assign(2.f); y.assign(9.f); exp.assign(1.f); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &z, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); ASSERT_TRUE(z.equalsTo(&exp)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), &y, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); 
ASSERT_TRUE(z.equalsTo(&exp)); nd4j::ops::reversemod subOp; @@ -1218,8 +1218,8 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ASSERT_TRUE(res->at(0)->equalsTo(exp)); auto z(exp); - x.applyTrueBroadcast(BROADCAST(ReverseDivide), &y, &z, true); - y.applyTrueBroadcast(BROADCAST(Divide), &x, &exp, true); + x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); ASSERT_TRUE(z.equalsTo(&exp)); @@ -1759,7 +1759,7 @@ TEST_F(DeclarableOpsTests1, Transpose1) { Nd4jStatus status = transpose.execute(block); ASSERT_EQ(ND4J_STATUS_OK, status); - // ASSERT_TRUE(x.isSameShapeStrict(&exp)); + // ASSERT_TRUE(x.isSameShapeStrict(exp)); for (int e = 0; e < x->rankOf() * 2 + 2; e++) { ASSERT_EQ(x->getShapeInfo()[e], exp->getShapeInfo()[e]); @@ -1790,7 +1790,7 @@ TEST_F(DeclarableOpsTests1, Transpose2) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // ASSERT_TRUE(result->isSameShapeStrict(&exp)); + // ASSERT_TRUE(result->isSameShapeStrict(exp)); for (int e = 0; e < result->rankOf() * 2 + 2; e++) { ASSERT_EQ(result->getShapeInfo()[e], exp->getShapeInfo()[e]); } @@ -1828,7 +1828,7 @@ TEST_F(DeclarableOpsTests1, Permute1) { Nd4jStatus status = permute.execute(block); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(x->isSameShapeStrict(exp)); + ASSERT_TRUE(x->isSameShapeStrict(*exp)); delete exp; delete block; @@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests1, Permute2) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(result->isSameShapeStrict(exp)); + ASSERT_TRUE(result->isSameShapeStrict(*exp)); delete block; delete variableSpace; @@ -2468,7 +2468,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { NDArray expGradX('c', {N,bS,2*K}, expGradXBuff); NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff); auto expGradB = NDArrayFactory::create('c', {4*K}); - 
gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}); // [bS, 4K] -> [4K] + gradBias.reduceAlongDimension(reduce::Sum, expGradB, {0}); // [bS, 4K] -> [4K] NDArray expGradInit('c', {bS,2*K}, expGradInitBuff); input.assign(1.5); @@ -2827,7 +2827,7 @@ TEST_F(DeclarableOpsTests1, Stack_1) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2855,7 +2855,7 @@ TEST_F(DeclarableOpsTests1, Stack_2) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2883,7 +2883,7 @@ TEST_F(DeclarableOpsTests1, Stack_3) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2910,7 +2910,7 @@ TEST_F(DeclarableOpsTests1, Stack_4) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2937,7 +2937,7 @@ TEST_F(DeclarableOpsTests1, Stack_5) { auto results = op.execute({&input1, &input2}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2964,7 +2964,7 @@ TEST_F(DeclarableOpsTests1, Stack_6) { auto results = op.execute({&input1, &input2}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + 
ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -2988,7 +2988,7 @@ TEST_F(DeclarableOpsTests1, Stack_7) { auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3011,7 +3011,7 @@ TEST_F(DeclarableOpsTests1, Stack_8) { auto results = op.execute({&input1, &input1, &input1}, {}, {0}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3034,7 +3034,7 @@ TEST_F(DeclarableOpsTests1, Stack_9) { auto results = op.execute({&input1, &input1, &input1}, {}, {1}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3060,7 +3060,7 @@ TEST_F(DeclarableOpsTests1, Stack_10) { //expected.printShapeInfo("exp"); //output->printShapeInfo("out"); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3082,7 +3082,7 @@ TEST_F(DeclarableOpsTests1, Stack_11) { auto results = op.execute({&input1, &input1, &input1}, {}, {}); auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3370,7 +3370,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3395,7 +3395,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2 ) { auto result = 
results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); delete results; @@ -3421,7 +3421,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3447,7 +3447,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3472,7 +3472,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3498,7 +3498,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); delete results; @@ -3526,7 +3526,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) { //expected.printIndexedBuffer("E"); //result->printIndexedBuffer("R"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3554,7 +3554,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8 ) { auto result = results->at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3579,7 +3579,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + 
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3618,7 +3618,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3640,7 +3640,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12 ) { auto result = results->at(0); //result->printIndexedBuffer("Result reverse"); //expected.printIndexedBuffer("Expected reverse"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3661,7 +3661,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -3682,7 +3682,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14 ) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 21c18299e..66cc487e1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2356,7 +2356,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2378,7 +2378,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { auto result = results->at(0); // result->printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - 
ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2398,7 +2398,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { auto result = results->at(0); // result->printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2419,7 +2419,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { NDArray* result = results->at(0); //result->printIndexedBuffer("OOOOUUUUTTT"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2440,7 +2440,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2462,7 +2462,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput3"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2485,7 +2485,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2507,7 +2507,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput4"); - 
ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2531,7 +2531,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput6"); // result->printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2555,7 +2555,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppression OUtput06"); // result->printShapeInfo("Ouput06 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2602,7 +2602,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2627,7 +2627,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2652,7 +2652,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { NDArray* result = results->at(0); // result->printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2677,7 +2677,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { 
auto result = results->at(0); // result->printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2701,7 +2701,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2725,7 +2725,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2749,7 +2749,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { auto result = results->at(0); // result->printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2773,7 +2773,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); //ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2811,7 +2811,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { result->syncToHost(); // result->printBuffer("Bounded boxes"); // expected.printBuffer("Bounded expec"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2844,7 +2844,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { // result->syncToHost(); // result->printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); 
ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2899,7 +2899,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { // result->syncToHost(); // result->printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2921,7 +2921,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { auto result = results->at(0); // result->printBuffer("Quantized"); // exp.printBuffer("Expected"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2941,7 +2941,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { auto result = results->at(0); // result->printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2962,7 +2962,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { auto result = results->at(0); // result->printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -2986,7 +2986,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { auto result = results->at(0); // result->printIndexedBuffer("Quantized03"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3009,7 +3009,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { auto result = results->at(0); // result->printIndexedBuffer("Quantized03_1"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3033,7 +3033,7 @@ 
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { auto result = results->at(0); result->printIndexedBuffer("Quantized03_2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3056,7 +3056,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { auto result = results->at(0); result->printIndexedBuffer("Quantized03_3"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3094,7 +3094,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { // exp.printBuffer("Quantized per channest E"); // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3148,7 +3148,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3182,7 +3182,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { // auto diff = *result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3225,7 +3225,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { auto result = results->at(0); // result->printBuffer("Quantized7"); // exp.printBuffer("Expected 7"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3251,7 +3251,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { // x.printBuffer("SourInput8"); // 
result->printBuffer("Quantized8"); // exp.printBuffer("Expected 8"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); ASSERT_TRUE(exp.equalsTo(result)); delete results; @@ -3279,7 +3279,7 @@ TEST_F(DeclarableOpsTests10, batchnorm_test1) { auto output = results->at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3294,7 +3294,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { auto gamma = NDArrayFactory::create('c', {4}); auto beta = NDArrayFactory::create('c', {4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, + auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); input.linspace(0.1, 0.1); @@ -3312,7 +3312,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { auto output = results->at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3327,7 +3327,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, + 
auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); input.linspace(0.1, 0.1); @@ -3340,7 +3340,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3355,7 +3355,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); input.linspace(0.1, 0.1); @@ -3368,7 +3368,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3397,7 +3397,7 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) { auto output = results->at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(output)); 
+ ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3425,7 +3425,7 @@ TEST_F(DeclarableOpsTests10, batchnorm_test6) { auto output = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete results; @@ -3441,7 +3441,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { NDArray result('c', {2,2,2}, nd4j::DataType::BOOL); - arr1.applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); + arr1.applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), arr2, result, true); // result.printIndexedBuffer(); // expd.printIndexedBuffer(); @@ -3474,7 +3474,7 @@ TEST_F(DeclarableOpsTests10, printIndexedTest_1) { // [[5 6] // [7 8]]] // - ResultSet* lastDims = arr.allTensorsAlongDimension({3}); // last dim + ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim size_t k = 0; // k from 0 to lastDims->size() Nd4jLong rank = 4; // in this case printf("["); @@ -3488,15 +3488,13 @@ TEST_F(DeclarableOpsTests10, printIndexedTest_1) { // printf("["); // else // printf(" "); - lastDims->at(k++)->printBuffer(); + lastDims.at(k++)->printBuffer(); //if (k == arr.sizeAt(i)) // printf("]\n"); } printf("]\n"); } printf("]\n"); - delete lastDims; - } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 37bcba233..2b01eca79 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -2785,11 +2785,11 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) { NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - auto sumDiff = 
labels.reduceAlongDims(reduce::Sum, {1}, true); + auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); NDArray numOfNonZero(sumDiff.getShapeInfo(), nd4j::DataType::INT64, false); numOfNonZero.assign(1); - sumDiff.applyPairwiseTransform(pairwise::SafeDivide, &numOfNonZero, &sumDiff, nullptr); + sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); } ///////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 0710e5506..c1cb872b4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -670,7 +670,7 @@ TEST_F(DeclarableOpsTests12, relu_1) { Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } @@ -825,7 +825,7 @@ TEST_F(DeclarableOpsTests12, pullRows_1) { TEST_F(DeclarableOpsTests12, pullRows_2) { NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); - NDArray* y = arr.dup('c'); + NDArray* y = new NDArray(arr.dup('c')); NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} NDArray z('c', {4, 1}, nd4j::DataType::DOUBLE); @@ -858,7 +858,7 @@ TEST_F(DeclarableOpsTests12, pullRows_2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, softmax_9) { NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, nd4j::DataType::FLOAT32); - NDArray* arrF = arrC.dup('f'); + NDArray* arrF = new NDArray(arrC.dup('f')); NDArray outCC('c', {5,2}, nd4j::DataType::FLOAT32); NDArray outCF('f', {5,2}, nd4j::DataType::FLOAT32); @@ -1395,7 +1395,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) { auto result = results->at(0); // result->printIndexedBuffer(); - 
ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1422,7 +1422,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1449,7 +1449,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1462,10 +1462,10 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, - 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, - 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, + 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); @@ -1480,7 +1480,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); // for(int i = 0; i < expected.lengthOf(); ++i) { @@ -1514,7 +1514,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1541,7 +1541,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1567,7 +1567,7 @@ TEST_F(DeclarableOpsTests12, pad_tests7) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete 
results; @@ -1593,7 +1593,7 @@ TEST_F(DeclarableOpsTests12, pad_tests8) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1619,7 +1619,7 @@ TEST_F(DeclarableOpsTests12, pad_tests9) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1663,7 +1663,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1692,7 +1692,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1714,7 +1714,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1735,7 +1735,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1756,7 +1756,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) { auto result 
= results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1777,7 +1777,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1798,7 +1798,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1819,7 +1819,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1840,7 +1840,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1884,7 +1884,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1907,7 +1907,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); 
ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1931,7 +1931,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) { // result->printShapeInfo("r"); // expected.printShapeInfo("e"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1953,7 +1953,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1975,7 +1975,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -1997,7 +1997,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) { auto result = results->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2017,7 +2017,7 @@ TEST_F(DeclarableOpsTests12, pad_tests27) { // z.printIndexedBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.isSameShapeStrict(&z)); + ASSERT_TRUE(exp.isSameShapeStrict(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -2143,7 +2143,7 @@ TEST_F(DeclarableOpsTests12, pad_tests34) { Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } @@ -2167,7 +2167,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2194,7 +2194,7 @@ 
TEST_F(DeclarableOpsTests12, Pad_2) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2221,7 +2221,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2248,7 +2248,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2275,7 +2275,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2302,7 +2302,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) { auto result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2328,7 +2328,7 @@ TEST_F(DeclarableOpsTests12, Pad_7) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2354,7 +2354,7 @@ TEST_F(DeclarableOpsTests12, Pad_8) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; @@ -2380,7 +2380,7 @@ 
TEST_F(DeclarableOpsTests12, Pad_9) auto *result = results->at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 044a41a37..c95599ff3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests13, test_or_1) { NDArray z('c', {4}, nd4j::DataType::BOOL); - x.applyPairwiseTransform(pairwise::Or, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Or, y, z); ASSERT_EQ(e, z); } @@ -181,7 +181,7 @@ TEST_F(DeclarableOpsTests13, test_and_1) { auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::And, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::And, y, z); ASSERT_EQ(e, z); } @@ -193,7 +193,7 @@ TEST_F(DeclarableOpsTests13, test_xor_1) { auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::Xor, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Xor, y, z); ASSERT_EQ(e, z); } @@ -1030,10 +1030,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, - 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, - 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, - 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, + auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, 
{0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, + 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, + 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, + 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f}); auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index cc72321e9..1f45d843f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) { TEST_F(DeclarableOpsTests15, Test_standarize_1) { auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5}, {0.f, 0f, 0.f, 0.f, 0.f}); + auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::standardize op; auto result = op.execute({&x}, {&x}, {}, {0}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 303580205..fb035bff1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -1,659 +1,657 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - - // - // @author raver119@gmail.com - // - -#include "testlayers.h" -#include -#include -#include -#include -#include - - -using namespace nd4j; - - -class DeclarableOpsTests16 : public testing::Test { -public: - - DeclarableOpsTests16() { - printf("\n"); - fflush(stdout); - } -}; - -TEST_F(DeclarableOpsTests16, scatter_upd_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); - auto y = NDArrayFactory::create(0); - auto w = NDArrayFactory::create(3.0f); - auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); - - nd4j::ops::scatter_upd op; - auto result = op.execute({ &x, &y, &w }, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, scatter_upd_2) { - - NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 2,5 }, nd4j::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); - NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, nd4j::DataType::FLOAT32); - - x.linspace(1); - - nd4j::ops::scatter_upd op; - auto result = op.execute({ &x, &indices, &updates }, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, scatter_upd_3) { - - NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 20,5 }, 
nd4j::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); - NDArray output('c', { 10, 3 }, nd4j::DataType::FLOAT32); - - nd4j::ops::scatter_upd op; - ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); -} - -TEST_F(DeclarableOpsTests16, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); - - nd4j::ops::size op; - auto status = op.execute({ &x }, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); -} - -TEST_F(DeclarableOpsTests16, test_empty_noop_1) { - auto z = NDArrayFactory::empty(); - - nd4j::ops::noop op; - auto status = op.execute({}, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_empty_noop_2) { - auto z = NDArrayFactory::empty(); - - Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - nd4j::ops::noop op; - auto status = op.execute(&ctx); - - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_svd_1) { - auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); - auto z = NDArrayFactory::create('c', { 3 }); - - nd4j::ops::svd op; - auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); - - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { - auto x = NDArrayFactory::create({ 37, 37, 37 }); - auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); - auto e = NDArrayFactory::create(18); - - nd4j::ops::bits_hamming_distance op; - auto result = op.execute({ &x, &y }, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { 
- auto input = NDArrayFactory::create('c', { 512 }); - auto low = NDArrayFactory::create('c', { 512 }); - auto high = NDArrayFactory::create('c', { 512 }); - - auto output = NDArrayFactory::create(0.0f); - - input.linspace(1.0); - low.linspace(1.0); - high.linspace(1.0); - - nd4j::ops::knn_mindistance op; - auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); -} - -TEST_F(DeclarableOpsTests16, test_empty_cast_1) { - auto x = NDArrayFactory::create('c', { 1, 0, 2 }); - auto e = NDArrayFactory::create('c', { 1, 0, 2 }); - - nd4j::ops::cast op; - auto result = op.execute({ &x }, {}, { 10 }); - ASSERT_EQ(Status::OK(), result->status()); - ASSERT_EQ(e, *result->at(0)); - - delete result; -} - -TEST_F(DeclarableOpsTests16, test_range_1) { - nd4j::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); - - Context ctx(1); - ctx.setTArguments({ -1.0, 1.0, 0.01 }); - ctx.setOutputArray(0, &z); - - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_range_2) { - nd4j::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); - - double tArgs[] = { -1.0, 1.0, 0.01 }; - - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); - shape::printShapeInfoLinear("Result", shapes->at(0)); - ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); - - delete shapes; -} - -TEST_F(DeclarableOpsTests16, test_reverse_1) { - std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; - std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; - - for (auto r : rows) { - for (auto c : columns) { - //nd4j_printf("Trying [%i, %i]\n", r, c); - auto array = NDArrayFactory::create('c', { r, c }); - auto exp = NDArrayFactory::create('c', { r, c }); - auto reversed = NDArrayFactory::create('c', { r, c }); - - auto rowOriginal = NDArrayFactory::create('c', { c }); - auto rowReversed = NDArrayFactory::create('c', { 
c }); - - for (int e = 0; e < c; e++) { - rowOriginal.p(e, (float)e); - rowReversed.p(c - e - 1, (float)e); - } - - - auto listI = array.allTensorsAlongDimension({ 1 }); - auto listE = exp.allTensorsAlongDimension({ 1 }); - - for (int e = 0; e < r; e++) { - listI->at(e)->assign(rowOriginal); - listE->at(e)->assign(rowReversed); - } - - delete listI; - delete listE; - - nd4j::ops::reverse op; - Nd4jLong axis = 1; - auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(exp, reversed); - } - } -} - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { - /* - test case generated by python colorsys and scaled to suit our needs - from colorsys import * - from random import * - import numpy as np - rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) - hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) - rgbs.ravel() - hsvs.ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, - 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, - 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, - 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, - 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, - 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, - 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, - 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, - 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, - 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 
0.418478161f, - 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, - 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, - 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, - 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, - 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, - 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, - 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, - 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, - 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, - 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); -#if 0 - //visual check - rgbs.printBuffer("rgbs "); - actual.printBuffer("HSV "); - expected.printBuffer("exp"); -#endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { - /* - swapped_rgbs=rgbs.swapaxes(1,2).ravel() - swapped_hsvs=hsvs.swapaxes(1,2).ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, - 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, - 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, - 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, - 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, - 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, - 
0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, - 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, - 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, - 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, - 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, - 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, - 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, - 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, - 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, - 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, - 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, - 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, - 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { - - auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, - 0.890151322f, 0.740797927f, 
0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { - auto rgbs = NDArrayFactory::create('c', { 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - - -TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 
0.890151322f, - 0.928968489f, 0.684074104f - }); - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - //get subarray - std::unique_ptr subArrRgbs(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) })); - std::unique_ptr expected(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) })); - subArrRgbs->reshapei({ 3 }); - expected->reshapei({ 3 }); -#if 0 - //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] - subArrRgbs->printShapeInfo("subArrRgbs"); -#endif - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, subArrRgbs.get()); - ctx.setOutputArray(0, &actual); - nd4j::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected->equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { - - auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, - 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, - 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, - 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, - 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, - 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, - 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, - 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, - 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, - 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.257768334f, 
0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, - 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, - 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, - 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, - 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, - 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, - 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, - 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, - 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, - 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { - auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, - 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, - 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, - 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, - 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, - 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, - 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, - 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, - 0.447288662f, 
0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, - 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, - 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, - 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, - 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, - 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, - 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, - 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, - 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, - 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, - 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f - }); - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 1 }); - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { - auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 4,3 }); - - Context ctx(1); - ctx.setInputArray(0, 
&hsvs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - ctx.setIArguments({ 0 }); - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { - - auto hsvs = NDArrayFactory::create('c', { 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); - - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - -} - - -TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { - - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 
0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - //get subarray - std::unique_ptr subArrHsvs(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) })); - subArrHsvs->reshapei({ 3 }); - std::unique_ptr expected(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) })); - expected->reshapei({ 3 }); -#if 0 - //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] - subArrHsvs->printShapeInfo("subArrHsvs"); -#endif - - Context ctx(1); - ctx.setInputArray(0, subArrHsvs.get()); - ctx.setOutputArray(0, &actual); - nd4j::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected->equalsTo(actual)); - -} +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests16 : public testing::Test { +public: + + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests16, scatter_upd_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); + auto y = NDArrayFactory::create(0); + auto w = NDArrayFactory::create(3.0f); + auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); + + nd4j::ops::scatter_upd op; + auto result = op.execute({ &x, &y, &w }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, scatter_upd_2) { + + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 2,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, nd4j::DataType::FLOAT32); + + x.linspace(1); + + nd4j::ops::scatter_upd op; + auto result = op.execute({ &x, &indices, &updates }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, scatter_upd_3) { + + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 20,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray output('c', { 10, 3 }, nd4j::DataType::FLOAT32); + + nd4j::ops::scatter_upd op; + ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); +} + +TEST_F(DeclarableOpsTests16, 
test_size_dtype_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + nd4j::ops::size op; + auto status = op.execute({ &x }, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_1) { + auto z = NDArrayFactory::empty(); + + nd4j::ops::noop op; + auto status = op.execute({}, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_2) { + auto z = NDArrayFactory::empty(); + + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::noop op; + auto status = op.execute(&ctx); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_svd_1) { + auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); + auto z = NDArrayFactory::create('c', { 3 }); + + nd4j::ops::svd op; + auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { + auto x = NDArrayFactory::create({ 37, 37, 37 }); + auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); + auto e = NDArrayFactory::create(18); + + nd4j::ops::bits_hamming_distance op; + auto result = op.execute({ &x, &y }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { + auto input = NDArrayFactory::create('c', { 512 }); + auto low = NDArrayFactory::create('c', { 512 }); + auto high = NDArrayFactory::create('c', { 512 }); + + auto output = NDArrayFactory::create(0.0f); + + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); + + nd4j::ops::knn_mindistance op; + auto result = 
op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests16, test_empty_cast_1) { + auto x = NDArrayFactory::create('c', { 1, 0, 2 }); + auto e = NDArrayFactory::create('c', { 1, 0, 2 }); + + nd4j::ops::cast op; + auto result = op.execute({ &x }, {}, { 10 }); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_EQ(e, *result->at(0)); + + delete result; +} + +TEST_F(DeclarableOpsTests16, test_range_1) { + nd4j::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + Context ctx(1); + ctx.setTArguments({ -1.0, 1.0, 0.01 }); + ctx.setOutputArray(0, &z); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_range_2) { + nd4j::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + double tArgs[] = { -1.0, 1.0, 0.01 }; + + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); + shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + + delete shapes; +} + +TEST_F(DeclarableOpsTests16, test_reverse_1) { + std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; + std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; + + for (auto r : rows) { + for (auto c : columns) { + //nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', { r, c }); + auto exp = NDArrayFactory::create('c', { r, c }); + auto reversed = NDArrayFactory::create('c', { r, c }); + + auto rowOriginal = NDArrayFactory::create('c', { c }); + auto rowReversed = NDArrayFactory::create('c', { c }); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); + } + + + auto listI = array.allTensorsAlongDimension({ 1 }); + auto listE = exp.allTensorsAlongDimension({ 1 }); + + for (int e = 0; e < r; e++) { + listI.at(e)->assign(rowOriginal); + 
listE.at(e)->assign(rowReversed); + } + + nd4j::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); + } + } +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) + rgbs.ravel() + hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, + 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, + 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, + 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, + 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, + 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, + 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, + 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, + 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, + 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, + 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, + 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, + 0.679180562f, 0.674094439f, 
0.540247589f, 0.870388806f, 0.279114306f, + 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, + 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, + 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, + 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, + 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, + 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, + 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, + 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, + 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, + 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, + 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, + 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, + 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f + }); + auto expected = 
NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { + + auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, 
status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { + auto rgbs = NDArrayFactory::create('c', { 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + //get subarray + //get 
subarray + NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrRgbs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs->printShapeInfo("subArrRgbs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrRgbs); + ctx.setOutputArray(0, &actual); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { + + auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, + 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, + 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, + 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, + 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, + 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, + 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, + 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, + 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, + 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, + 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, + 0.645954334f, 0.464318275f, 0.337166255f, 
0.358530475f, 0.594427716f, + 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, + 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, + 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, + 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, + 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, + 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, + 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { + auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, + 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, + 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, + 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, + 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, + 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, + 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, + 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, + 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, + 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 
0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, + 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, + 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, + 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, + 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, + 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, + 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, + 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, + 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f + }); + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { + auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 
0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + + auto hsvs = NDArrayFactory::create('c', { 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { + + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + //get subarray + NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrHsvs.reshapei({ 3 }); + NDArray expected = 
rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrHsvs->printShapeInfo("subArrHsvs"); +#endif + + Context ctx(1); + ctx.setInputArray(0, &subArrHsvs); + ctx.setOutputArray(0, &actual); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index a8377b429..e4d0db62c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -47,7 +47,7 @@ TEST_F(DeclarableOpsTests2, gather_1) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests2, gather_2) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -90,7 +90,7 @@ TEST_F(DeclarableOpsTests2, gather_3) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests2, gather_4) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -131,7 +131,7 @@ TEST_F(DeclarableOpsTests2, gather_5) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -153,7 +153,7 @@ TEST_F(DeclarableOpsTests2, 
gather_6) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests2, gather_7) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -197,7 +197,7 @@ TEST_F(DeclarableOpsTests2, gather_8) { // output->printShapeInfo(); // output->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests2, gather_13) { auto* output = result->at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.isSameShapeStrict(*output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; @@ -440,7 +440,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); - auto exp = x.dup(); + auto exp = new NDArray(x.dup()); nd4j::ops::squeeze op; auto result = op.execute({&x}, {}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 5322a0a6d..dacfac127 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result0 = op.execute({&x}, {0.}, {}); auto z0 = result0->at(0); - auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); @@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { 
ASSERT_EQ(result1->status(), ND4J_STATUS_OK); auto z1 = result1->at(0); // z1->printIndexedBuffer("Z1"); - auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); // exp1.printIndexedBuffer("EXP1"); // z1->printShapeInfo("Z1 shape"); // exp1.printShapeInfo("EXP1 shape"); @@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result4 = op.execute({&x}, {4.}, {1}); auto z4 = result4->at(0); - auto exp4= x.reduceAlongDims(reduce::NormMax, dims, false, false); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); ASSERT_TRUE(exp4.isSameShape(z4)); ASSERT_TRUE(exp4.equalsTo(z4)); @@ -233,7 +233,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result0 = op.execute({&x}, {0}, {}); auto z0 = result0->at(0); - auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); @@ -242,7 +242,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result1 = op.execute({&x, &axis}, {1}, {}); auto z1 = result1->at(0); - auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.equalsTo(z1)); @@ -251,7 +251,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { auto result4 = op.execute({&x, &axis}, {4}, {}); auto z4 = result4->at(0); - auto exp4= x.reduceAlongDims(reduce::NormMax, dims, false, false); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); ASSERT_TRUE(exp4.isSameShape(z4)); ASSERT_TRUE(exp4.equalsTo(z4)); @@ -329,21 +329,21 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { x.linspace(100.); - auto xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true); + auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); x /= xNorm1; - xNorm1 = 
x.reduceAlongDims(reduce::Norm2,{1}, true); + xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); ASSERT_TRUE(unities.isSameShape(xNorm1)); ASSERT_TRUE(unities.equalsTo(xNorm1)); x *= scale; - xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true); + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); nd4j::ops::clipbynorm op; auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE); auto z = result->at(0); - auto zNorm1 = z->reduceAlongDims(reduce::Norm2, {1}, true); + auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); ASSERT_TRUE(exp.isSameShape(&zNorm1)); @@ -2432,17 +2432,11 @@ TEST_F(DeclarableOpsTests3, svd_test6) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test7) { - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 - ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 - ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 - ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 - ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 - ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 - ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, - 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, - 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. 
,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 + ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 + ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,38.18412, 31.52287, 23.52755, 11.79484, 1.90195, + 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 0, 16}); @@ -2623,75 +2617,25 @@ TEST_F(DeclarableOpsTests3, svd_test9) { 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, - 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, - 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677, - 0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, - 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912, - -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, - -0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989, - -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, - 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, - 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492, - 0.68875, 0.1822 , -0.08046, -0.39238, -0.57619, - 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, - -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 
0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, + -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,-0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234, + -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318,0.62118, -0.37993, 0.30992, 0.34537, -0.50444, + 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, - -1.10690000e-01, 1.37280000e-01, - 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, - -1.00090000e-01, 9.35890000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, - 6.70320000e-01, 2.10040000e-01, - 1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, - -4.38680000e-01, 1.83200000e-02, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01, 2.41550000e-01, - -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, - -5.40210000e-01, -4.97000000e-02, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, - -4.21850000e-01, 4.00490000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, - -2.06580000e-01, 7.68890000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, - 5.88210000e-01, 7.12900000e-02, - 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01, 
6.02100000e-02, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, - -5.79750000e-01, -2.92870000e-01, - 4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, - 2.72560000e-01, 3.92350000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, - -1.17970000e-01, -4.08100000e-02, - 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, - 1.93050000e-01, -6.83340000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, - 2.37500000e-02, 5.78250000e-01, - -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, - -9.19220000e-01, -2.15420000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, - -3.19580000e-01, 2.92020000e-01, - 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, - 3.39100000e-02, 2.55590000e-01, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, - 4.98470000e-01, -3.65370000e-01, - 6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, - -7.30460000e-01, -1.09390000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, - 5.20000000e-04, 1.90420000e-01, - 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, - -3.57600000e-02, -8.60450000e-01, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02, 2.17750000e-01, - -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, - -4.63400000e-01, -1.74620000e-01}); + auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01, 1.37280000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, 9.35890000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01, 2.10040000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01, 1.83200000e-02, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01,-2.10060000e-01, 
2.41550000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01, -4.97000000e-02, + -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, 4.00490000e-01,1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01, 7.68890000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01, 7.12900000e-02,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01,1.51570000e-01, 6.02100000e-02, + 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01, -2.92870000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, 3.92350000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01, -4.08100000e-02,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01, -6.83340000e-01, + 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02, 5.78250000e-01,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, -2.15420000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01, 2.92020000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02, 2.55590000e-01, + -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01, -3.65370000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, -1.09390000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04, 1.90420000e-01,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02, -8.60450000e-01, + 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {1, 1, 16}); @@ -2736,75 +2680,21 @@ 
TEST_F(DeclarableOpsTests3, svd_test10) { 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, - 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, - 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677, - 0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, - 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912, - -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, - -0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989, - -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, - 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, - 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492, - 0.68875, 0.1822 , -0.08046, -0.39238, -0.57619, - 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, - -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039,-0.61335, 0.10076, 0.01381, 0.40922, -0.66783, + -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234,-0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318, + 0.62118, -0.37993, 0.30992, 0.34537, -0.50444,0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658,0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , 
-0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, - -1.10690000e-01, - 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, - -1.00090000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, - 6.70320000e-01, - 1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, - -4.38680000e-01, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01, - -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, - -5.40210000e-01, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, - -4.21850000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, - -2.06580000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, - 5.88210000e-01, - 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, - -5.79750000e-01, - 4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, - 2.72560000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, - -1.17970000e-01, - 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, - 1.93050000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, - 2.37500000e-02, - -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, - -9.19220000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, - -3.19580000e-01, - 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, - 3.39100000e-02, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, - 4.98470000e-01, - 6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, - -7.30460000e-01, - -4.94030000e-01, 
1.55540000e-01, -3.46720000e-01, -7.58460000e-01, - 5.20000000e-04, - 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, - -3.57600000e-02, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02, - -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, - -4.63400000e-01}); + auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01,-5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01,-6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, + 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01,-4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, + 1.51570000e-01,1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01,8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02,-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 
8.90000000e-02,-7.30460000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02,1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, + -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 1, 16}); @@ -2865,8 +2755,36 @@ TEST_F(DeclarableOpsTests3, svd_test11) { ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); + + if(nd4j::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + } + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test12) { + + NDArray x('c', {4,3}, {1.7787856,0.80119777,0.72437465,0.23089433,1.7271413,0.18039072,0.50563407,0.89252293,1.5461209,0.92336726,0.085571885,0.79378015}); + NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); + + nd4j::ops::svd op; + auto results = op.execute({&x}, {}, {1, 0, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto s = results->at(0); + + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); delete results; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index c30ad5f89..6d85feec1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -423,7 +423,7 @@ TEST_F(DeclarableOpsTests5, Log1p_test1) { // auto eps = 
NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); // auto exp = NDArrayFactory::create('c', {3,3}); nd4j::ops::Log1p op; - y.applyTransform(nd4j::transform::Log, nullptr, nullptr); + y.applyTransform(nd4j::transform::Log, y); auto result = op.execute({&matrix}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -2737,7 +2737,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto res = NDArrayFactory::create('c', {2, 2, 2}); - input.applyScalar(nd4j::scalar::ELU, 1.f, &res); + input.applyScalar(nd4j::scalar::ELU, 1.f, res); ASSERT_TRUE(res.equalsTo(&exp)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 67cd56d5e..c52191b8a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { auto ones = onesRes->at(0); *ones *= 10; - auto onesD = ones->dup(); + auto onesD = new NDArray(ones->dup()); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, onesD); @@ -1577,31 +1577,31 @@ TEST_F(DeclarableOpsTests6, LogDet_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_1) { auto x = NDArrayFactory::create('c', {2, 5, 5}, { - 2.f, 4.f, 60.f, 8.f, 10.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 2.f, 4.f, 6.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 4.f, + 2.f, 4.f, 60.f, 8.f, 10.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 2.f, 4.f, 6.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 4.f, - 1.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 4.f, 3.f, 2.f, 1.f, 0.f, + 1.f, 0.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); auto exp = NDArrayFactory::create('c', {2, 5, 5}, { - 0.5f, -2.0f, -13.0f, 54.0f, 
-6.75f, - 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, + 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, - 0.f, 0.f, 0.f, 0.f, 0.25f, + 0.f, 0.f, 0.f, 0.f, 0.25f, - 1.0f, 0.0f, 0.0f, 0.0f, 0.f, - -2.0f, 1.0f, 0.f, 0.f, 0.f, + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, + -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, - -27.0f, 0.0f, 1.0f, -2.0f, 1.f, + -27.0f, 0.0f, 1.0f, -2.0f, 1.f, }); nd4j::ops::matrix_inverse op; @@ -1891,10 +1891,8 @@ TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { std::vector dims = {0, 1}; - auto z = x.applyReduce3(reduce3::CosineSimilarity, &y, dims, nullptr); - ASSERT_TRUE(z != nullptr); - - delete z; + auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); + ASSERT_TRUE(&z != nullptr); } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index ebe1f8e18..ef495142d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -3029,13 +3029,13 @@ TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { // auto expDeviance = NDArrayFactory::create('c', {10, 10}); auto squared = NDArrayFactory::create('c', {10, 10}); - data.applyTransform(transform::Square, &squared, nullptr); + data.applyTransform(transform::Square, squared); auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); // ssSquared->printBuffer("Sum squared"); // squared.printBuffer("Squared"); nd4j::ops::normalize_moments op; - auto results = op.execute({&counts, means, ssSquared}, {0.0}, {0}); - (*means) /= counts; + auto results = op.execute({&counts, &means, &ssSquared}, {0.0}, {0}); + means /= counts; // nd4j::ops::normalize_moments op; // auto results = op.execute({&counts, means, deviance}, {0.0}, {}); @@ -3049,13 +3049,11 @@ TEST_F(DeclarableOpsTests8, 
NormalizeMoments_SGO_1) { // outputDeviance->printIndexedBuffer("Variance"); // deviance.printIndexedBuffer("Expected"); // means->printIndexedBuffer("Expected means"); - ASSERT_TRUE(means->isSameShape(outputMeans)); - ASSERT_TRUE(means->equalsTo(outputMeans)); + ASSERT_TRUE(means.isSameShape(outputMeans)); + ASSERT_TRUE(means.equalsTo(outputMeans)); ASSERT_TRUE(deviance.isSameShape(outputDeviance)); ASSERT_TRUE(deviance.equalsTo(outputDeviance)); - delete means; //delete deviance; - delete ssSquared; // ASSERT_TRUE(expMeans.isSameShape(outputMeans)); // ASSERT_TRUE(expMeans.equalsTo(outputMeans)); // ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); @@ -3636,60 +3634,60 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); x.linspace(1); - auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 
0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 
0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 
0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index dfbfc90a8..6df52fb54 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); - x.applyScalar(scalar::Add, 1.0, &z); + x.applyScalar(scalar::Add, 1.0, z); ASSERT_EQ(e, z); } @@ -634,10 +634,7 @@ TEST_F(DeclarableOpsTests9, concat_test18) { for (int e = 0; e < 2000; e++) { auto row = z.tensorAlongDimension(e, {1}); - - ASSERT_NEAR((float) e, row->e(0), 1e-5f); - - delete row; + ASSERT_NEAR((float) e, row.e(0), 1e-5f); } } @@ -1684,7 +1681,7 @@ TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) { auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, &y, &z, nullptr); + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { @@ -1697,7 +1694,7 @@ TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { auto z = 
NDArrayFactory::create('c', {1, 3, 2, 4, 4}); std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, &y, &z, nullptr); + x.applyBroadcast(broadcast::LessThan, dims, y, z); } @@ -1746,7 +1743,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) { auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); auto expect = NDArrayFactory::create('c', {bS, nOut}); - auto norm2 = x.reduceAlongDims(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] + auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] auto y = ( (x / norm2) * clip) * colVect ; auto temp = (x / norm2) * clip; @@ -2927,13 +2924,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -2970,13 +2967,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3012,13 +3009,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); 
ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3051,13 +3048,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3092,13 +3089,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3133,13 +3130,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { auto dLdG = results->at(3); auto dLdB = results->at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3179,13 +3176,13 @@ 
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; @@ -3224,13 +3221,13 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index 8ae123260..ca1479210 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -197,7 +197,7 @@ TEST_F(EmptyTests, Test_Reshape_3) { TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); - auto dup = empty.dup(); + auto dup = new NDArray(empty.dup()); ASSERT_TRUE(dup->isEmpty()); ASSERT_EQ(empty, *dup); diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 085127e74..0bf9a1eb7 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -69,8 +69,7 @@ TEST_F(HelpersTests1, evalHHmatrix_test1) { auto exp = NDArrayFactory::create('c', {4,4}, {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, -0.0632377, -0.0210792, -0.13484,-0.0632377, 0.98884,-0.00371987, 
-0.0449467,-0.0210792,-0.00371987, 0.99876}); auto result = ops::helpers::Householder::evalHHmatrix(x); - - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -86,7 +85,7 @@ TEST_F(HelpersTests1, evalHHmatrix_test2) { auto result = ops::helpers::Householder::evalHHmatrix(x); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -109,7 +108,7 @@ TEST_F(HelpersTests1, evalHHmatrixData_test1) { ASSERT_NEAR(normX, normXExpected, 1e-5); ASSERT_NEAR(coeff, coeffExpected, 1e-5); - ASSERT_TRUE(tail.isSameShapeStrict(&expTail)); + ASSERT_TRUE(tail.isSameShapeStrict(expTail)); ASSERT_TRUE(tail.equalsTo(&expTail)); } @@ -128,7 +127,7 @@ TEST_F(HelpersTests1, Householder_mulLeft_test1) { ops::helpers::Householder::mulLeft(x, tail, 0.1); // expTail.printShapeInfo(); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -145,7 +144,7 @@ TEST_F(HelpersTests1, Householder_mulLeft_test2) { ops::helpers::Householder::mulLeft(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -162,7 +161,7 @@ TEST_F(HelpersTests1, Householder_mulRight_test1) { ops::helpers::Householder::mulRight(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(&exp)); + ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -181,9 +180,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); 
ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -200,9 +199,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -219,9 +218,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(&object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(&object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } @@ -241,8 +240,8 @@ TEST_F(HelpersTests1, HHsequence_test1) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -268,8 +267,8 @@ TEST_F(HelpersTests1, HHsequence_test2) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + 
ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -295,8 +294,8 @@ TEST_F(HelpersTests1, HHsequence_test3) { ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(&vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); @@ -870,9 +869,9 @@ TEST_F(HelpersTests1, SVD_test12) { ASSERT_TRUE(expU.equalsTo(&U)); ASSERT_TRUE(expV.equalsTo(&V)); - ASSERT_TRUE(expSingVals.isSameShapeStrict(&singVals)); - ASSERT_TRUE(expU.isSameShapeStrict(&U)); - ASSERT_TRUE(expV.isSameShapeStrict(&V)); + ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); + ASSERT_TRUE(expU.isSameShapeStrict(U)); + ASSERT_TRUE(expV.isSameShapeStrict(V)); } /////////////////////////////////////////////////////////////////// @@ -893,9 +892,9 @@ TEST_F(HelpersTests1, SVD_test13) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -917,9 +916,9 @@ TEST_F(HelpersTests1, SVD_test14) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - 
ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -941,9 +940,9 @@ TEST_F(HelpersTests1, SVD_test15) { ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(&qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(&qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -1246,9 +1245,9 @@ TEST_F(HelpersTests1, SVD_test16) { svd.DivideAndConquer(0, 3, 1, 1, 1); // svd._m.printIndexedBuffer(); - ASSERT_TRUE(expM.isSameShapeStrict(&svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); ASSERT_TRUE(expM.equalsTo(&svd._m)); ASSERT_TRUE(expU.equalsTo(&svd._u)); @@ -1281,9 +1280,9 @@ TEST_F(HelpersTests1, SVD_test17) { ASSERT_TRUE(expU.equalsTo(&svd._u)); ASSERT_TRUE(expV.equalsTo(&svd._v)); - ASSERT_TRUE(expM.isSameShapeStrict(&svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); } // /////////////////////////////////////////////////////////////////// @@ -1329,9 +1328,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// 
ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1378,9 +1377,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1427,9 +1426,9 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expU.equalsTo(&svd._u)); // ASSERT_TRUE(expV.equalsTo(&svd._v)); -// ASSERT_TRUE(expS.isSameShapeStrict(&svd._s)); -// ASSERT_TRUE(expU.isSameShapeStrict(&svd._u)); -// ASSERT_TRUE(expV.isSameShapeStrict(&svd._v)); +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } @@ -1444,7 +1443,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), outArr.getBuffer(), outArr.getShapeInfo()); // // ASSERT_TRUE(outArr.equalsTo(&exp)); -// ASSERT_TRUE(outArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); //} // // @@ -1458,7 +1457,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), inArr.getBuffer(), inArr.getShapeInfo()); // // ASSERT_TRUE(inArr.equalsTo(&exp)); -// ASSERT_TRUE(inArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(inArr.isSameShapeStrict(exp)); //} // // @@ -1472,7 +1471,7 @@ TEST_F(HelpersTests1, SVD_test17) { // ops::helpers::reverseArray(nd4j::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.getShapeInfo(), outArr.getBuffer(), outArr.getShapeInfo(), 5); // // 
ASSERT_TRUE(outArr.equalsTo(&exp)); -// ASSERT_TRUE(outArr.isSameShapeStrict(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); //} /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp index 96c480fd9..790279f74 100644 --- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp @@ -42,18 +42,10 @@ TEST_F(IndexingTests, StridedSlice_1) { auto begin = NDArrayFactory::create({2,2, 0}); auto end = NDArrayFactory::create({3,3,3}); auto strides = NDArrayFactory::create({1,1,1}); - //nd4j_debug("print x->rankOf(): %i", x.rankOf()); - /* - auto tads = x.allTensorsAlongDimension({0}); - nd4j_debug("numTads: %i\n", tads->size()); - for (int e = 0; e < tads->size(); e++) - tads->at(e)->assign((float) e); - */ nd4j::ops::strided_slice op; -// auto result = op.execute({&x}, {}, {0,0,0,0,0, 2,2,0, 3,3,3, 1,1,1}); auto result = op.execute({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -202,8 +194,8 @@ TEST_F(IndexingTests, SimpleSlice_4) { TEST_F(IndexingTests, MaskedSlice_0) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {1, 5}); @@ -222,15 +214,14 @@ TEST_F(IndexingTests, MaskedSlice_0) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_00) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = 
NDArrayFactory::create('c', {1, 2}, {2, 2}); @@ -243,21 +234,18 @@ TEST_F(IndexingTests, MaskedSlice_00) { auto z = result->at(0); - // z->printShapeInfo("z"); - ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_1) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {5}); @@ -276,7 +264,6 @@ TEST_F(IndexingTests, MaskedSlice_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(IndexingTests, MaskedSlice_2) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index e7f7f7e68..0e8db97ff 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1230,18 +1230,18 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) { TEST_F(JavaInteropTests, test_ismax_view) { auto original = NDArrayFactory::create('c', {2, 3, 40}); auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); - v->assign(1.0); + v.assign(1.0); - auto e = v->like(); + auto e = v.like(); auto t = e.tensorAlongDimension(0, {0, 1}); - t->assign(1.0); + t.assign(1.0); - auto z = v->ulike(); + auto z = v.ulike(); Nd4jLong iArgs[] = {2L, 0L}; Context ctx(1); - ctx.setInputArray(0, v->buffer(), v->shapeInfo(), v->specialBuffer(), v->specialShapeInfo()); + ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setIArguments(iArgs, 1); @@ -1249,9 +1249,6 @@ TEST_F(JavaInteropTests, test_ismax_view) { op.execute(&ctx); ASSERT_EQ(e, z); - - delete v; - delete t; } 
TEST_F(JavaInteropTests, test_size_dtype_1) { diff --git a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu index 30244b7dc..5bf8c8b57 100644 --- a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu @@ -75,7 +75,7 @@ void test(NDArray &x) { return x+1.; }; - x.applyLambda(f, &x); + x.applyLambda(f, x); } template @@ -84,7 +84,7 @@ void test2(NDArray &x) { return x+1.; }; - x.applyLambda(f, &x); + x.applyLambda(f, x); } void testPairwise(NDArray &x, NDArray &y) { @@ -92,7 +92,7 @@ void testPairwise(NDArray &x, NDArray &y) { return x + y +1.; }; - x.applyPairwiseLambda(&y, f, &x); + x.applyPairwiseLambda(y, f, x); } void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { @@ -100,7 +100,7 @@ void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { return i + j + k + 2.; }; - i.applyTriplewiseLambda(&j, &k, f, &i); + i.applyTriplewiseLambda(j, k, f, i); } void testIndexed(NDArray &x) { @@ -108,7 +108,7 @@ void testIndexed(NDArray &x) { return _idx + 1.; }; - x.applyIndexedLambda(f, &x); + x.applyIndexedLambda(f, x); } void testIndexedPairwise(NDArray &x, NDArray &y) { @@ -116,7 +116,7 @@ void testIndexedPairwise(NDArray &x, NDArray &y) { return _idx + x + y +1.; }; - x.applyIndexedPairwiseLambda(&y, f, &x); + x.applyIndexedPairwiseLambda(y, f, x); } TEST_F(LambdaTests, test_basic_2) { @@ -197,7 +197,7 @@ void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) { + nd4j::math::nd4j_exp(-nd4j::math::nd4j_abs(x))); }; - x.applyPairwiseLambda(&y, f, &z); + x.applyPairwiseLambda(y, f, z); } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index ffcd5759e..cb4d4d07d 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -194,11 +194,10 @@ TEST_F(LegacyOpsTests, ReduceTests_2) { auto 
exp = x.reduceAlongDimension(reduce::Sum, {1}); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete exp; } @@ -211,7 +210,7 @@ TEST_F(LegacyOpsTests, ReduceTests_3) { nd4j::ops::LegacyReduceSameOp op(reduce::Sum); auto result = op.execute({&x, &indices}, {}, {}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Sum,{1}); + auto exp = x.reduceAlongDimension(reduce::Sum,{1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -231,7 +230,7 @@ TEST_F(LegacyOpsTests, ReduceTests_4) { nd4j::ops::LegacyReduceSameOp op(reduce::Sum); auto result = op.execute({&x, &indices}, {}, {}, {true}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Sum, {1}, true); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); // indices.printShapeInfo("Indices shape"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // z->printIndexedBuffer("Output reduce 4"); @@ -275,11 +274,10 @@ TEST_F(LegacyOpsTests, ReduceTests_6) { auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete exp; } @@ -292,7 +290,7 @@ TEST_F(LegacyOpsTests, ReduceTests_7) { nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); auto result = op.execute({&x, &indices}, {}, {}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Mean,{1}); + auto exp = x.reduceAlongDimension(reduce::Mean,{1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -312,7 +310,7 @@ TEST_F(LegacyOpsTests, ReduceTests_8) { nd4j::ops::LegacyReduceFloatOp op(reduce::Mean); auto result = op.execute({&x, &indices}, {}, {}, {true}); auto z = result->at(0); - auto exp = x.reduceAlongDims(reduce::Mean, {1}, true); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // z->printIndexedBuffer("Reduce8 output"); 
@@ -382,10 +380,8 @@ TEST_F(LegacyOpsTests, BroadcastingTests_1) { auto list = x.allTensorsAlongDimension({1}); // x.printIndexedBuffer("Output broadcast"); // list->at(0)->printIndexedBuffer("Column 0:"); - for (int e = 0; e < list->size(); e++) - ASSERT_TRUE(row.equalsTo(list->at(e))); - - delete list; + for (int e = 0; e < list.size(); e++) + ASSERT_TRUE(row.equalsTo(list.at(e))); } TEST_F(LegacyOpsTests, BroadcastingTests_2) { @@ -417,7 +413,7 @@ TEST_F(LegacyOpsTests, PowDerivative_1) { float p = 2.0f; - x.applyScalar(scalar::PowDerivative, p); + x.applyScalar(scalar::PowDerivative, p, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -661,10 +657,10 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { e.assign(false); auto row = y.tensorAlongDimension(1, {1}); - row->assign(2.0f); + row.assign(2.0f); auto erow = e.tensorAlongDimension(1, {1}); - erow->assign(true); + erow.assign(true); auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); @@ -680,9 +676,6 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); ASSERT_EQ(e, z); - - delete row; - delete erow; } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index e506839df..625d9978f 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -59,7 +59,7 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); list.write(e, row); - tads->at(e)->assign(row); + tads.at(e)->assign(row); } nd4j::ops::stack_list op; @@ -75,7 +75,6 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { @@ -86,7 +85,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) 
{ auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); //list.write(e, row); - tads->at(e)->assign(row); + tads.at(e)->assign(row); delete row; } @@ -103,13 +102,12 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { // ASSERT_TRUE(exp.equalsTo(z)); for (int e = 0; e < 10; e++) { auto row = list.read(e); - ASSERT_TRUE(row->equalsTo(tads->at(e))); + ASSERT_TRUE(row->equalsTo(tads.at(e))); //list.write(e, row); delete row; } delete result; - delete tads; } //TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { @@ -153,7 +151,7 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -179,16 +177,16 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } auto tads = exp.allTensorsAlongDimension({1}); - tads->at(0)->assign(1.0f); - tads->at(1)->assign(1.0f); - tads->at(2)->assign(3.0f); - tads->at(3)->assign(3.0f); + tads.at(0)->assign(1.0f); + tads.at(1)->assign(1.0f); + tads.at(2)->assign(3.0f); + tads.at(3)->assign(3.0f); nd4j::ops::pick_list op; @@ -202,7 +200,6 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, BasicTest_Size_1) { @@ -211,7 +208,7 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -272,14 +269,14 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {5}); row->assign((double) e); - tads->at(e)->assign(row); + 
tads.at(e)->assign(row); if (e < 2) - tads0->at(cnt0++)->assign(row); + tads0.at(cnt0++)->assign(row); else if (e < 5) - tads1->at(cnt1++)->assign(row); + tads1.at(cnt1++)->assign(row); else - tads2->at(cnt2++)->assign(row); + tads2.at(cnt2++)->assign(row); delete row; } @@ -300,10 +297,6 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); delete result; - delete tads; - delete tads0; - delete tads1; - delete tads2; } TEST_F(ListOperationsTests, BasicTest_Scatter_1) { @@ -315,7 +308,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 5}); row->assign((double) e); - tads->at(e)->assign(row); + tads.at(e)->assign(row); delete row; } @@ -329,15 +322,13 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); for (int e = 0; e < 10; e++) { - auto row = tads->at(9 - e); + auto row = tads.at(9 - e); auto chunk = list.readRaw(e); ASSERT_TRUE(chunk->isSameShape(row)); ASSERT_TRUE(chunk->equalsTo(row)); } - - delete tads; delete result; } @@ -376,7 +367,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {3}); row->assign((double) e); - list.write(e, row->dup()); + list.write(e, new NDArray(row->dup())); delete row; } @@ -384,7 +375,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { auto exp = NDArrayFactory::create('c', {10, 3}); auto tads = exp.allTensorsAlongDimension({1}); for (int e = 0; e < 10; e++) { - auto tad = tads->at(9 - e); + auto tad = tads.at(9 - e); tad->assign(e); } @@ -407,7 +398,6 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { ASSERT_TRUE(exp.equalsTo(z)); delete result; - delete tads; } TEST_F(ListOperationsTests, GraphTests_Sequential_1) { @@ -415,17 +405,16 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { auto matrix = NDArrayFactory::create_('c', {3, 3}); auto tads = 
matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {3, 3}); auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp->at(0)->assign(0.f); - tadsExp->at(1)->assign(-1.f); - tadsExp->at(2)->assign(-2.f); - delete tadsExp; + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); auto indices = NDArrayFactory::valueOf({3}, 1, 'c'); //indices->linspace(0); @@ -472,7 +461,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { // nodeF1->setCustomOp(&opF); // nodeF2->setCustomOp(&opF); - // now we're stacking chunks back to matrix state + // now we're stacking chunks back to matrix state nd4j::ops::stack_list opG; auto nodeG = new Node(&opG, 20, {2, 15, 16, 17}); //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); @@ -537,8 +526,6 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_1) { ASSERT_TRUE(exp.isSameShape(stack)); ASSERT_TRUE(exp.equalsTo(stack)); - - delete tads; } @@ -548,16 +535,15 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { auto scalar = NDArrayFactory::create_(0.0f); auto matrix = NDArrayFactory::create_('c', {3, 3}); auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - tads->at(e)->assign((float) (e+1)); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); } - auto exp = NDArrayFactory::create('c', {3, 3}); auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp->at(0)->assign(0.f); - tadsExp->at(1)->assign(-1.f); - tadsExp->at(2)->assign(-2.f); + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); //auto indices = NDArray::valueOf({1, 3}, 1.0f, 'c'); auto indices = NDArrayFactory::create_('c', {1, 3}); @@ -580,7 +566,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { // filling list with matrix 
nd4j::ops::scatter_list opC; auto nodeC = new Node(&opC, 3, {2, -2, 1, -3}); - + //nodeC->setCustomOp(&opC); nd4j::ops::read_list opD; @@ -608,7 +594,7 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { // nodeF1->setCustomOp(&opF); // nodeF2->setCustomOp(&opF); - // now we're gathering chunks back to matrix state + // now we're gathering chunks back to matrix state nd4j::ops::pick_list opG; auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17}); //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); @@ -665,14 +651,11 @@ TEST_F(ListOperationsTests, GraphTests_Sequential_2) { ASSERT_EQ(3, list->elements()); ASSERT_TRUE(variableSpace->hasVariable(20)); - + auto stack = variableSpace->getVariable(20)->getNDArray(); - + ASSERT_TRUE(stack != nullptr); ASSERT_TRUE(exp.isSameShape(stack)); ASSERT_TRUE(exp.equalsTo(stack)); - - delete tadsExp; - delete tads; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index d271048a9..9afc34267 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -237,18 +237,14 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64); NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64); - auto* scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); - ASSERT_EQ(*scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto* scalar2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {}/*whole range*/, true); - ASSERT_EQ(*scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {}/*whole range*/, true); + ASSERT_EQ(scalar2, exp2); - auto* scalar3 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {1}); - ASSERT_EQ(*scalar3, exp3); - - delete scalar1; - delete scalar2; - delete 
scalar3; + auto scalar3 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {1}); + ASSERT_EQ(scalar3, exp3); } //////////////////////////////////////////////////////////////////////////////// @@ -257,16 +253,13 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32); - auto* scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); // scalar1->printShapeInfo(); // scalar1->printIndexedBuffer(); - ASSERT_EQ(*scalar1, exp1); + ASSERT_EQ(scalar1, exp1); - auto* scalar2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); - ASSERT_EQ(*scalar2, exp2); - - delete scalar1; - delete scalar2; + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// @@ -275,10 +268,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF); NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF); - auto scalar1 = x.reduceAlongDims(nd4j::reduce::Sum, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + auto scalar2 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_EQ(scalar2, exp2); } @@ -288,10 +281,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL); - auto scalar1 = x.reduceAlongDims(nd4j::reduce::IsPositive, {}/*whole range*/); + auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + auto scalar2 = 
x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_EQ(scalar2, exp2); } @@ -974,22 +967,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformFloat_test1) { NDArray result2('c', {2,2}, nd4j::DataType::DOUBLE); NDArray result3('c', {2,2}, nd4j::DataType::HALF); - x1.applyTransform(nd4j::transform::Sqrt, &result1); + x1.applyTransform(nd4j::transform::Sqrt, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::Sqrt, &result2); + x2.applyTransform(nd4j::transform::Sqrt, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::Sqrt, &result3); + x3.applyTransform(nd4j::transform::Sqrt, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::Sqrt, &result3); + x4.applyTransform(nd4j::transform::Sqrt, result3); ASSERT_EQ(result3, exp4); - x2.applyTransform(nd4j::transform::Sqrt); + x2.applyTransform(nd4j::transform::Sqrt, x2); ASSERT_EQ(x2, exp3); - x3.applyTransform(nd4j::transform::Sqrt); + x3.applyTransform(nd4j::transform::Sqrt, x3); ASSERT_EQ(x3, exp2); } @@ -1016,25 +1009,25 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformSame_test1) { NDArray result4('c', {2,2}, nd4j::DataType::BOOL); NDArray result5('c', {3,2}, nd4j::DataType::DOUBLE); - x1.applyTransform(nd4j::transform::Square, &result1); + x1.applyTransform(nd4j::transform::Square, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::Square, &result2); + x2.applyTransform(nd4j::transform::Square, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::Square, &result3); + x3.applyTransform(nd4j::transform::Square, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::Square, &result4); + x4.applyTransform(nd4j::transform::Square, result4); ASSERT_EQ(result4, exp4); - x2.applyTransform(nd4j::transform::Square); + x2.applyTransform(nd4j::transform::Square, x2); ASSERT_EQ(x2, exp2); - x3.applyTransform(nd4j::transform::Square); + x3.applyTransform(nd4j::transform::Square, x3); ASSERT_EQ(x3, 
exp3); - x5.applyTransform(nd4j::transform::Square, &result5); + x5.applyTransform(nd4j::transform::Square, result5); ASSERT_EQ(result5, exp5); } @@ -1057,19 +1050,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformBool_test1) { NDArray result2('c', {3,2}, nd4j::DataType::BOOL); /* - x1.applyTransform(nd4j::transform::IsMax, &result1); + x1.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::IsMax, &result1); + x2.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x3.applyTransform(nd4j::transform::IsMax, &result1); + x3.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp1); - x4.applyTransform(nd4j::transform::IsMax, &result1); + x4.applyTransform(nd4j::transform::IsMax, result1); ASSERT_EQ(result1, exp2); - x5.applyTransform(nd4j::transform::IsMax, &result2); + x5.applyTransform(nd4j::transform::IsMax, result2); ASSERT_EQ(result2, exp3); */ } @@ -1095,28 +1088,28 @@ TEST_F(MultiDataTypeTests, ndarray_applyTransformStrict_test1) { NDArray result3('c', {2,2}, nd4j::DataType::DOUBLE); NDArray result4('c', {3,2}, nd4j::DataType::DOUBLE); - x1.applyTransform(nd4j::transform::CubeDerivative, &result1); + x1.applyTransform(nd4j::transform::CubeDerivative, result1); ASSERT_EQ(result1, exp1); - x2.applyTransform(nd4j::transform::CubeDerivative, &result2); + x2.applyTransform(nd4j::transform::CubeDerivative, result2); ASSERT_EQ(result2, exp2); - x3.applyTransform(nd4j::transform::CubeDerivative, &result3); + x3.applyTransform(nd4j::transform::CubeDerivative, result3); ASSERT_EQ(result3, exp3); - x4.applyTransform(nd4j::transform::CubeDerivative, &result4); + x4.applyTransform(nd4j::transform::CubeDerivative, result4); ASSERT_EQ(result4, exp4); - x1.applyTransform(nd4j::transform::CubeDerivative); + x1.applyTransform(nd4j::transform::CubeDerivative, x1); ASSERT_EQ(x1, exp1); - x2.applyTransform(nd4j::transform::CubeDerivative); + 
x2.applyTransform(nd4j::transform::CubeDerivative, x2); ASSERT_EQ(x2, exp2); - x3.applyTransform(nd4j::transform::CubeDerivative); + x3.applyTransform(nd4j::transform::CubeDerivative, x3); ASSERT_EQ(x3, exp3); - x4.applyTransform(nd4j::transform::CubeDerivative); + x4.applyTransform(nd4j::transform::CubeDerivative, x4); ASSERT_EQ(x4, exp5); } @@ -1138,19 +1131,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test1) { NDArray exp4('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, nd4j::DataType::DOUBLE); NDArray exp5('c', {3,2}, {0, 2, 4, 6, 8, 5}, nd4j::DataType::INT32); - x1.applyPairwiseTransform(nd4j::pairwise::Add, &x4, &x5, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, x5); ASSERT_EQ(x5, exp5); - x1.applyPairwiseTransform(nd4j::pairwise::Add, &x4, &x6, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, x6); ASSERT_EQ(x6, exp4); - x1.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x1, exp1); - x2.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x2.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x2, exp2); - x3.applyPairwiseTransform(nd4j::pairwise::Add, x4, nullptr); + x3.applyPairwiseTransform(nd4j::pairwise::Add, x4); ASSERT_EQ(x3, exp3); } @@ -1173,13 +1166,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test2) { NDArray exp2('c', {2,3}, {1, 0, 1, 1, 0, 1}, nd4j::DataType::BOOL); NDArray exp3('c', {2,3}, {0, 1, 0, 0, 0, 0}, nd4j::DataType::BOOL); - x1.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x2, &x7, nullptr); + x1.applyPairwiseTransform(nd4j::pairwise::EqualTo, x2, x7); ASSERT_EQ(x7, exp1); - x3.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x4, &x8, nullptr); + x3.applyPairwiseTransform(nd4j::pairwise::EqualTo, x4, x8); ASSERT_EQ(x8, exp2); - x5.applyPairwiseTransform(nd4j::pairwise::EqualTo, &x6, &x8, nullptr); + x5.applyPairwiseTransform(nd4j::pairwise::EqualTo, x6, x8); ASSERT_EQ(x8, exp3); } @@ 
-1199,13 +1192,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test1) { NDArray exp2('c', {2,3}, {11, 21, 31, 42, 52, 62}, nd4j::DataType::FLOAT32); NDArray exp3('c', {2,3}, {11, 21, 31, 41, 51, 61}, nd4j::DataType::INT32); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x2, &x3); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x2, x3); ASSERT_EQ(x3, exp1); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x4, &x5); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x4, x5); ASSERT_EQ(x5, exp2); - x1.applyBroadcast(nd4j::broadcast::Add, {0}, &x6, &x3); + x1.applyBroadcast(nd4j::broadcast::Add, {0}, x6, x3); ASSERT_EQ(x3, exp3); } @@ -1222,10 +1215,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test2) { NDArray exp1('c', {2,3}, {1, 0, 0, 0, 0, 1}, nd4j::DataType::BOOL); NDArray exp2('c', {2,3}, {1, 1, 1, 0, 0, 1}, nd4j::DataType::BOOL); - x1.applyBroadcast(nd4j::broadcast::EqualTo, {0}, &x2, &x3); + x1.applyBroadcast(nd4j::broadcast::EqualTo, {0}, x2, x3); ASSERT_EQ(x3, exp1); - x4.applyBroadcast(nd4j::broadcast::EqualTo, {0}, &x5, &x3); + x4.applyBroadcast(nd4j::broadcast::EqualTo, {0}, x5, x3); ASSERT_EQ(x3, exp2); } @@ -1256,13 +1249,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { NDArray exp4('c', {0}, {4.5}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, nd4j::DataType::DOUBLE); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x2, &x3, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3); ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x4, &x5, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x4, x5); ASSERT_EQ(x5, exp2); - x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x7, &x8, true); + x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x7, x8); ASSERT_EQ(x8, exp3); auto x9 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2); @@ -1274,17 +1267,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { 
auto x11 = x6.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x7); ASSERT_EQ(x11, exp3); - auto x12 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x2); - ASSERT_EQ(*x12, exp1); - delete x12; + auto x12 = x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x12, exp1); - x13.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x14, &x15, true); + x13.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x14, x15); ASSERT_EQ(x15, exp4); - x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x14, &x16, true); + x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x14, x16); ASSERT_EQ(x16, exp5); - x14.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), &x1, &x16, true); + x14.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x1, x16); ASSERT_EQ(x16, exp5); } @@ -1305,16 +1297,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); NDArray exp3('c', {0}, {0}, nd4j::DataType::BOOL); - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x2, &x3, true); + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3); ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x4, &x3, true); + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x4, x3); ASSERT_EQ(x3, exp2); - x4.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x1, &x3, true); + x4.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x1, x3); ASSERT_EQ(x3, exp2); - x5.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), &x4, &x6, 
true); + x5.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x4, x6); ASSERT_EQ(x6, exp3); } @@ -1334,19 +1326,19 @@ TEST_F(MultiDataTypeTests, ndarray_applyScalar_test1) { NDArray exp4('c', {2,2}, {1.1, 2.1, 1.1, 2.1}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL); - x1.applyScalar(nd4j::scalar::Add, 1); + x1.applyScalar(nd4j::scalar::Add, 1, x1); ASSERT_EQ(x1, exp1); - x1.applyScalar(nd4j::scalar::Add, 0.5, &x3); + x1.applyScalar(nd4j::scalar::Add, 0.5, x3); ASSERT_EQ(x3, exp2); - x2.applyScalar(nd4j::scalar::Add, 0.1); + x2.applyScalar(nd4j::scalar::Add, 0.1, x2); ASSERT_EQ(x2, exp3); - x4.applyScalar(nd4j::scalar::Add, 1.1, &x3); + x4.applyScalar(nd4j::scalar::Add, 1.1, x3); ASSERT_EQ(x3, exp4); - x4.applyScalar(nd4j::scalar::Add, 1); + x4.applyScalar(nd4j::scalar::Add, 1, x4); ASSERT_EQ(x4, exp5); } @@ -1362,13 +1354,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyScalar_test2) { NDArray exp1('c', {2,2}, {0, 1, 0, 0}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {0, 1, 1, 0}, nd4j::DataType::BOOL); - x1.applyScalar(nd4j::scalar::EqualTo, 1, &x4); + x1.applyScalar(nd4j::scalar::EqualTo, 1, x4); ASSERT_EQ(x4, exp1); - x2.applyScalar(nd4j::scalar::EqualTo, 1.5, &x4); + x2.applyScalar(nd4j::scalar::EqualTo, 1.5, x4); ASSERT_EQ(x4, exp1); - x3.applyScalar(nd4j::scalar::EqualTo, true, &x4); + x3.applyScalar(nd4j::scalar::EqualTo, true, x4); ASSERT_EQ(x4, exp2); } @@ -1399,22 +1391,23 @@ TEST_F(MultiDataTypeTests, ndarray_applyLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); - x1.applyLambda(func1, &x4); + x1.applyLambda(func1, x4); ASSERT_EQ(x4, exp1); - x2.applyLambda(func1); + x2.applyLambda(func1, x2); ASSERT_EQ(x2, exp2); - x2.applyLambda(func2); + x2.applyLambda(func2, x2); ASSERT_EQ(x2, exp2); - x3.applyLambda(func3); + x3.applyLambda(func3, x3); 
ASSERT_EQ(x3, exp3); - x5.applyLambda(func4); + x5.applyLambda(func4, x5); + // x5.printBuffer(); ASSERT_EQ(x5, exp4); - x6.applyLambda(func5, &x7); + x6.applyLambda(func5, x7); ASSERT_EQ(x7, exp5); } @@ -1444,22 +1437,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); NDArray exp6('c', {2,2}, {0, 3, 6, 9}, nd4j::DataType::INT64); - x1.applyIndexedLambda(func1, &x4); + x1.applyIndexedLambda(func1, x4); ASSERT_EQ(x4, exp1); - x2.applyIndexedLambda(func1); + x2.applyIndexedLambda(func1, x2); ASSERT_EQ(x2, exp2); - x2.applyIndexedLambda(func2); + x2.applyIndexedLambda(func2, x2); ASSERT_EQ(x2, exp6); - x3.applyIndexedLambda(func3); + x3.applyIndexedLambda(func3, x3); ASSERT_EQ(x3, exp3); - x5.applyIndexedLambda(func4); + x5.applyIndexedLambda(func4, x5); ASSERT_EQ(x5, exp4); - x6.applyIndexedLambda(func5, &x7); + x6.applyIndexedLambda(func5, x7); ASSERT_EQ(x7, exp5); } @@ -1490,22 +1483,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - x1.applyPairwiseLambda(&other2, func1, &x4); + x1.applyPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); - x2.applyPairwiseLambda(&other3, func1); + x2.applyPairwiseLambda(other3, func1, x2); ASSERT_EQ(x2, exp2); - x2.applyPairwiseLambda(&other3, func2); + x2.applyPairwiseLambda(other3, func2, x2); ASSERT_EQ(x2, other3); - x3.applyPairwiseLambda(&other1, func3); + x3.applyPairwiseLambda(other1, func3, x3); ASSERT_EQ(x3, exp3); - x5.applyPairwiseLambda(&other1, func4); + x5.applyPairwiseLambda(other1, func4, x5); ASSERT_EQ(x5, exp4); - x6.applyPairwiseLambda(&other4, func5, &x7); + x6.applyPairwiseLambda(other4, func5, x7); ASSERT_EQ(x7, exp5); } @@ -1536,22 +1529,22 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, 
nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); - x1.applyIndexedPairwiseLambda(&other2, func1, &x4); + x1.applyIndexedPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); - x2.applyIndexedPairwiseLambda(&other3, func1); + x2.applyIndexedPairwiseLambda(other3, func1, x2); ASSERT_EQ(x2, exp2); - x2.applyIndexedPairwiseLambda(&other3, func2); + x2.applyIndexedPairwiseLambda(other3, func2, x2); ASSERT_EQ(x2, exp2); - x3.applyIndexedPairwiseLambda(&other1, func3); + x3.applyIndexedPairwiseLambda(other1, func3, x3); ASSERT_EQ(x3, exp3); - x5.applyIndexedPairwiseLambda(&other1, func4); + x5.applyIndexedPairwiseLambda(other1, func4, x5); ASSERT_EQ(x5, exp4); - x6.applyIndexedPairwiseLambda(&other4, func5, &x7); + x6.applyIndexedPairwiseLambda(other4, func5, x7); ASSERT_EQ(x7, exp5); } @@ -1578,16 +1571,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { NDArray exp('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); - x1.applyTriplewiseLambda(&x2, &x3, func1, &x4); + x1.applyTriplewiseLambda(x2, x3, func1, x4); ASSERT_EQ(x4, x2); - x1.applyTriplewiseLambda(&x2, &x3, func2); + x1.applyTriplewiseLambda(x2, x3, func2, x1); ASSERT_EQ(x1, x3); - x5.applyTriplewiseLambda(&x6, &x7, func3); + x5.applyTriplewiseLambda(x6, x7, func3, x5); ASSERT_EQ(x5, x7); - x8.applyTriplewiseLambda(&x9, &x10, func4); + x8.applyTriplewiseLambda(x9, x10, func4, x8); ASSERT_EQ(x8, exp); } @@ -1601,18 +1594,14 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); - NDArray* scalar = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_EQ(*scalar, exp1); + NDArray scalar = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); + ASSERT_EQ(scalar, exp1); - NDArray* vec1 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_EQ(*vec1, exp2); + NDArray vec1 = 
x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); + ASSERT_EQ(vec1, exp2); - NDArray* vec2 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_EQ(*vec2, exp3); - - delete scalar; - delete vec1; - delete vec2; + NDArray vec2 = x1.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////////////// @@ -1626,13 +1615,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_EQ(scalar, exp1); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {1}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {1}); ASSERT_EQ(vec1, exp2); - x1.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {0}); + x1.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {0}); ASSERT_EQ(vec2, exp3); } @@ -1646,13 +1635,11 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) { NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); - auto result = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_EQ(*result, exp1); - delete result; + auto result = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_EQ(*result, exp2); - delete result; + result = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// @@ -1674,29 +1661,23 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) { NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE); NDArray exp6('c', {2}, {9,22.5}, nd4j::DataType::DOUBLE); - auto result = x1.applyReduce3(reduce3::Dot, &x2, {0,1}); - ASSERT_EQ(*result, exp1); - delete result; + auto result = 
x1.applyReduce3(reduce3::Dot, x2, {0,1}); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, &x4, {0,1}); - ASSERT_EQ(*result, exp2); - delete result; + result = x3.applyReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_EQ(result, exp2); - result = x5.applyReduce3(reduce3::Dot, &x6, std::vector({0})); - ASSERT_EQ(*result, exp3); - delete result; + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); + ASSERT_EQ(result, exp3); - result = x5.applyReduce3(reduce3::Dot, &x6, std::vector({1})); - ASSERT_EQ(*result, exp4); - delete result; + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); + ASSERT_EQ(result, exp4); - result = x8.applyReduce3(reduce3::Dot, &x7, std::vector({0})); - ASSERT_EQ(*result, exp5); - delete result; + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); + ASSERT_EQ(result, exp5); - result = x8.applyReduce3(reduce3::Dot, &x7, std::vector({1})); - ASSERT_EQ(*result, exp6); - delete result; + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); + ASSERT_EQ(result, exp6); } ////////////////////////////////////////////////////////////////////// @@ -1709,13 +1690,11 @@ TEST_F(MultiDataTypeTests, applyAllReduce3_test1) { NDArray exp1('c', {2,3}, {2,-2,2,2,-2,2}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2,3}, {6,6,6,9,9,9}, nd4j::DataType::DOUBLE); - auto result = x1.applyAllReduce3(reduce3::Dot, &x2, {0}); - ASSERT_EQ(*result, exp1); - delete result; + auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_EQ(result, exp1); - result = x4.applyAllReduce3(reduce3::Dot, &x3, {0}); - ASSERT_EQ(*result, exp2); - delete result; + result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// @@ -1734,16 +1713,16 @@ TEST_F(MultiDataTypeTests, RowCol_test1) { NDArray exp3('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, nd4j::DataType::DOUBLE); NDArray exp4('c', {2,3}, {0,1,1,2,3,3}, nd4j::DataType::INT32); - 
x1.addiRowVector(&x3); + x1.addiRowVector(x3); ASSERT_EQ(x1, exp1); - x1.addiColumnVector(&x2); + x1.addiColumnVector(x2); ASSERT_EQ(x1, exp1); - x4.addiColumnVector(&x2); + x4.addiColumnVector(x2); ASSERT_EQ(x4, exp3); - x5.muliColumnVector(&x2); + x5.muliColumnVector(x2); ASSERT_EQ(x5, exp4); } @@ -1770,22 +1749,22 @@ TEST_F(MultiDataTypeTests, RowCol_test2) { NDArray exp5('c', {2,3}, {1,1,1,4,2.5,2}, nd4j::DataType::DOUBLE); NDArray exp6('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, nd4j::DataType::FLOAT32); - x1.addRowVector(&x3, &x4); + x1.addRowVector(x3, x4); ASSERT_EQ(x4, exp1); - x1.addRowVector(&x5, &x6); + x1.addRowVector(x5, x6); ASSERT_EQ(x6, exp2); - x8.subRowVector(&x7, &x4); + x8.subRowVector(x7, x4); ASSERT_EQ(x4, exp3); - x1.mulRowVector(&x9, &x10); + x1.mulRowVector(x9, x10); ASSERT_EQ(x10, exp4); - x1.divRowVector(&x9, &x10); + x1.divRowVector(x9, x10); ASSERT_EQ(x10, exp5); - x1.addColumnVector(&x2, &x4); + x1.addColumnVector(x2, x4); ASSERT_EQ(x4, exp6); } @@ -1826,25 +1805,6 @@ TEST_F(MultiDataTypeTests, tile_test1) { } */ -////////////////////////////////////////////////////////////////////// -TEST_F(MultiDataTypeTests, broadcast_test1) { - - NDArray x1('c', {2,1,3}, nd4j::DataType::INT32); - NDArray x2('c', {2,4,1}, nd4j::DataType::INT64); - NDArray x3('c', {2,4,1}, nd4j::DataType::DOUBLE); - - NDArray exp1('c', {2,4,3}, nd4j::DataType::INT32); - NDArray exp2('c', {2,4,3}, nd4j::DataType::DOUBLE); - - auto result = x1.broadcast(x2); - ASSERT_TRUE(result->isSameShapeStrict(&exp1)); - delete result; - - result = x1.broadcast(x3); - ASSERT_TRUE(result->isSameShapeStrict(&exp2)); - delete result; -} - ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, asT_test1) { @@ -1853,19 +1813,19 @@ TEST_F(MultiDataTypeTests, asT_test1) { NDArray exp1('c', {2}, {1, 2}, nd4j::DataType::INT32); NDArray exp2('c', {2}, {1.5, 2.5}, nd4j::DataType::DOUBLE); - auto result = x1.asT(); + auto result = new NDArray(x1.asT()); 
ASSERT_EQ(*result, exp1); delete result; - result = x1.asT(); + result = new NDArray(x1.asT()); ASSERT_EQ(*result, exp2); delete result; - result = x1.asT(nd4j::DataType::INT32); + result = new NDArray(x1.asT(nd4j::DataType::INT32)); ASSERT_EQ(*result, exp1); delete result; - result = x1.asT(nd4j::DataType::DOUBLE); + result = new NDArray(x1.asT(nd4j::DataType::DOUBLE)); ASSERT_EQ(*result, exp2); delete result; } @@ -1904,7 +1864,7 @@ TEST_F(MultiDataTypeTests, Test_Cast_1) { asBool.assign(first); // asBool.printIndexedBuffer("asBool"); - asBool.applyScalar(scalar::Not, false, &_not); + asBool.applyScalar(scalar::Not, false, _not); // _not.printIndexedBuffer("_not"); @@ -1925,7 +1885,7 @@ TEST_F(MultiDataTypeTests, Test_Cast_2) { asBool.assign(first); // asBool.printIndexedBuffer("asBool"); - asBool.applyTransform(transform::Not, &_not); + asBool.applyTransform(transform::Not, _not); // _not.printIndexedBuffer("_not"); @@ -1968,7 +1928,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.divRowVector(&x4, &x3); + x1.divRowVector(x4, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); @@ -1976,7 +1936,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, &x4, &x3); + x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, x4, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); @@ -1984,7 +1944,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) { } try { - x1.applyTrueBroadcast(BROADCAST(FloorMod), &x2, &x3, true); + x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); } catch (std::exception& message) { // printf("%s\n", message.what()); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 7740cd1ac..3cdd8f70a 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -125,7 +125,7 @@ 
TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) { ASSERT_FALSE(x->isActualOnHostSide()); NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Neg, y, nullptr); + x->applyTransform(transform::Neg, *y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -145,7 +145,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) { ASSERT_FALSE(x->isActualOnHostSide()); NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Cosine, y, nullptr); + x->applyTransform(transform::Cosine, *y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -278,7 +278,7 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Add, y, z); // // cudaFree(devBufferPtrX); @@ -400,7 +400,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // x.printBuffer("3X = "); // y.printBuffer("3Y = "); // z.printBuffer("3Result out"); @@ -432,7 +432,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // // cudaFree(devBufferPtrX); @@ -461,7 +461,7 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { //ASSERT_EQ(0, res); //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, 
&y, &z, nullptr); + x.applyPairwiseTransform(pairwise::Multiply, y, z); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); // z.printBuffer("23Result out"); @@ -539,7 +539,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveNeg_2) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Neg, &y, nullptr); + x.applyTransform(transform::Neg, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -559,7 +559,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Sqrt, &y, nullptr); + x.applyTransform(transform::Sqrt, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -580,7 +580,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict //ASSERT_TRUE(x.isActualOnDeviceSide()); //ASSERT_TRUE(x.isActualOnHostSide()); - x.applyTransform(transform::Assign, &y, nullptr); + x.applyTransform(transform::Assign, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -606,7 +606,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -628,7 +628,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -657,7 +657,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) { ASSERT_TRUE(x.isActualOnDeviceSide()); ASSERT_FALSE(x.isActualOnHostSide()); - 
x.applyTransform(transform::Cosine, &y, nullptr); + x.applyTransform(transform::Cosine, y); //ASSERT_TRUE(x->isActualOnDeviceSide()); //ASSERT_FALSE(x->isActualOnHostSide()); @@ -857,7 +857,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; // z.printBuffer("53Result out"); // @@ -890,7 +890,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; // z.printBuffer("52Result out"); @@ -924,7 +924,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); //x.printBuffer("23X = "); //y.printBuffer("23Y = "); - x.applyPairwiseTransform(pairwise::Multiply, &y, &z);// *= y; + x.applyPairwiseTransform(pairwise::Multiply, y, z);// *= y; // z.printBuffer("51Result out"); @@ -1059,7 +1059,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { //x.printBuffer("23X = "); //y.printBuffer("23Y = "); //void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &exp); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); // // cudaFree(devBufferPtrX); @@ -1106,10 +1106,7 @@ TEST_F(NDArrayCudaBasicsTests, TestDup1) { ASSERT_TRUE(array.equalsTo(arrF)); ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(arrF->equalsTo(arrC)); - - delete arrC; - delete arrF; + ASSERT_TRUE(arrF.equalsTo(arrC)); } 
////////////////////////////////////////////////////////////////////////// @@ -1169,27 +1166,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32); - NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + NDArray z = x.applyReduce3(nd4j::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(nd4j::reduce3::Dot, &k, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x.applyReduce3(nd4j::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({0,2,1}); y.permutei({0,2,1}); - z = y.applyReduce3(nd4j::reduce3::Dot, &x, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - // printCudaGlobal<<<1,1,0, *y.getContext()->getCudaStream()>>>(z->specialBuffer(), 6); - delete z; + z = y.applyReduce3(nd4j::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); x2.permutei({1,0,2}); - z = x2.applyReduce3(nd4j::reduce3::Dot, &k2, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x2.applyReduce3(nd4j::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// @@ -1206,27 +1198,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_2) { NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, nd4j::DataType::DOUBLE); - NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + NDArray z = x.applyReduce3(nd4j::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(nd4j::reduce3::Dot, &k, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x.applyReduce3(nd4j::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({0,2,1}); y.permutei({0,2,1}); - z = y.applyReduce3(nd4j::reduce3::Dot, &x, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - // printCudaGlobal<<<1,1,0, 
*y.getContext()->getCudaStream()>>>(z->specialBuffer(), 6); - delete z; + z = y.applyReduce3(nd4j::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); x2.permutei({1,0,2}); - z = x2.applyReduce3(nd4j::reduce3::Dot, &k2, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x2.applyReduce3(nd4j::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// @@ -1241,26 +1228,22 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE); - auto z = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + auto z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); x1.permutei({2,1,0}); x2.permutei({2,1,0}); x3.permutei({1,0}); x4.permutei({1,0}); - z = x1.applyReduce3(reduce3::Dot, &x2); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, &x4); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); } //////////////////////////////////////////////////////////////////////////// @@ -1278,37 +1261,28 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); - auto z = x1.applyAllReduce3(reduce3::Dot, &x2, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x1.applyAllReduce3(reduce3::Dot, &x2, {0}); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + 
ASSERT_TRUE(z.equalsTo(&exp2)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {1}); - // z->syncToHost(); - // z->printShapeInfo(); - // z->printIndexedBuffer(); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); + ASSERT_TRUE(z.equalsTo(&exp4)); x1.permutei({2,1,0}); x2.permutei({2,1,0}); x3.permutei({1,0}); x4.permutei({1,0}); - z = x1.applyAllReduce3(reduce3::Dot, &x2, {0,2}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyAllReduce3(reduce3::Dot, &x4, {0}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); + ASSERT_TRUE(z.equalsTo(&exp4)); } ////////////////////////////////////////////////////////////////////////////// @@ -1328,24 +1302,24 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_TRUE(scalar.equalsTo(&exp1)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {1}); ASSERT_TRUE(vec1.equalsTo(&exp2)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {0}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {0}); ASSERT_TRUE(vec2.equalsTo(&exp3)); x.permutei({1,0}); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &scalar, {0,1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, scalar, {0,1}); ASSERT_TRUE(scalar.equalsTo(&exp4)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec1, {0}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec1, {0}); 
ASSERT_TRUE(vec1.equalsTo(&exp5)); - x.applyIndexReduce(nd4j::indexreduce::IndexMax, &vec2, {1}); + x.applyIndexReduce(nd4j::indexreduce::IndexMax, vec2, {1}); ASSERT_TRUE(vec2.equalsTo(&exp6)); } @@ -1364,30 +1338,24 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); auto z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp1)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp1)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z->equalsTo(&exp2)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp2)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z->equalsTo(&exp3)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp3)); x.permutei({1,0}); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z->equalsTo(&exp4)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp4)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z->equalsTo(&exp5)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp5)); z = x.applyIndexReduce(nd4j::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z->equalsTo(&exp6)); - delete z; + ASSERT_TRUE(z.equalsTo(&exp6)); } //////////////////////////////////////////////////////////////////////////////// @@ -1407,24 +1375,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32); - x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::Mean, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - 
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::Mean, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::Mean, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Mean, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1439,24 +1407,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::DOUBLE); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::Mean, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::Mean, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::Mean, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::Mean, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::Mean, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::Mean, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::Mean, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::Mean, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1519,24 +1487,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32); - x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z1, 
{0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::Sum, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::Sum, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::Sum, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::Sum, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1551,24 +1519,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {21,5}, nd4j::DataType::INT64); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::Sum, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::Sum, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::Sum, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::Sum, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::Sum, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::Sum, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::Sum, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ 
-1589,24 +1557,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::IsPositive, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::IsPositive, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1621,24 +1589,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,1,2}); + NDArray z1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = 
x.reduceAlongDims(nd4j::reduce::IsPositive, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::IsPositive, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::IsPositive, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1659,24 +1627,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z2, {1}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z2, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z3, {0,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z3, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - x.reduceAlongDimension(nd4j::reduce::CountZero, &z1, {0,1,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z4, {1}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z4, {1}); ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(nd4j::reduce::CountZero, &z5, {0,2}); + x.reduceAlongDimension(nd4j::reduce::CountZero, z5, {0,2}); ASSERT_TRUE(z5.equalsTo(&exp5)); } @@ -1691,24 +1659,24 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {2,2}, nd4j::DataType::INT64); - NDArray z1 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,1,2}); + NDArray z1 = 
x.reduceAlongDimension(nd4j::reduce::CountZero, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDims(nd4j::reduce::CountZero, {1}); + NDArray z2 = x.reduceAlongDimension(nd4j::reduce::CountZero, {1}); ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,2}); + NDArray z3 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,2}); ASSERT_TRUE(z3.equalsTo(&exp3)); x.permutei({1,0,2}); // 3x2x2 - NDArray z4 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,1,2}); + NDArray z4 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,1,2}); ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDims(nd4j::reduce::CountZero, {1}); + NDArray z5 = x.reduceAlongDimension(nd4j::reduce::CountZero, {1}); ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDims(nd4j::reduce::CountZero, {0,2}); + NDArray z6 = x.reduceAlongDimension(nd4j::reduce::CountZero, {0,2}); ASSERT_TRUE(z6.equalsTo(&exp5)); } @@ -1722,7 +1690,7 @@ TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest1) { ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row, &z, nullptr); + x.applyBroadcast(broadcast::Add, {1}, *row, z); x += *row; ASSERT_TRUE(x.equalsTo(z)); @@ -1740,7 +1708,7 @@ TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest2) { NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, nd4j::DataType::FLOAT32); ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row); + x.applyBroadcast(broadcast::Add, {1}, *row, x); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -1753,7 +1721,7 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcast_1) { auto bias = NDArrayFactory::create('c', {1, 3}); bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, &bias); + input.applyBroadcast(broadcast::Add, {1}, bias, input); ASSERT_TRUE(exp.equalsTo(&input)); } @@ -1807,7 +1775,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) expected = 3.; res2 = 0.f; - 
x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2);// *= y; ASSERT_TRUE(expected.isSameShape(&res2)); ASSERT_TRUE(expected.equalsTo(&res2)); @@ -2095,20 +2063,20 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); auto diag = x.diagonal('c'); - //diag->syncToDevice(); + //diag.syncToDevice(); for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - printf("VAL[%ld] = %f\n", e, diag->e(e)); //, exp.e(e), 1.e-5); + printf("VAL[%ld] = %f\n", e, diag.e(e)); //, exp.e(e), 1.e-5); } for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - ASSERT_NEAR(diag->e(e), exp.e(e), 1.e-5); + ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); } double eps(1.e-5); NDArray tmp(nd4j::DataType::FLOAT32, x.getContext()); // scalar = 0 ExtraArguments extras({eps}); - NativeOpExecutioner::execReduce3Scalar(diag->getContext(), reduce3::EqualsWithEps, diag->getBuffer(), - diag->getShapeInfo(), diag->getSpecialBuffer(), diag->getSpecialShapeInfo(), extras.argumentsAsT(nd4j::DataType::FLOAT32), + NativeOpExecutioner::execReduce3Scalar(diag.getContext(), reduce3::EqualsWithEps, diag.getBuffer(), + diag.getShapeInfo(), diag.getSpecialBuffer(), diag.getSpecialShapeInfo(), extras.argumentsAsT(nd4j::DataType::FLOAT32), exp.getBuffer(), exp.getShapeInfo(), exp.getSpecialBuffer(), exp.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); cudaStream_t* stream = x.getContext()->getCudaStream(); @@ -2116,8 +2084,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { // tmp.printBuffer("Compare result is (expected 0)"); ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp index cc8549e81..e57c7e625 100644 --- 
a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -36,7 +36,7 @@ TEST_F(NDArrayListTests, BasicTests_1) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 10}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); } @@ -47,7 +47,7 @@ TEST_F(NDArrayListTests, BasicTests_2) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 7}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y)); } @@ -63,7 +63,7 @@ TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { ASSERT_EQ(10, list.elements()); - auto array = list.stack(); + auto array = list.stack(); ASSERT_TRUE(input.isSameShape(array)); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index d0fb4bf37..fb55b4484 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -52,8 +52,8 @@ TEST_F(NDArrayTest, TestDup1) { NDArray array(arr1, shape1); - auto arrC = array.dup('c'); - auto arrF = array.dup('f'); + auto arrC = new NDArray(array.dup('c')); + auto arrF = new NDArray(array.dup('f')); ASSERT_TRUE(array.equalsTo(arrF)); ASSERT_TRUE(array.equalsTo(arrC)); @@ -87,8 +87,8 @@ TEST_F(NDArrayTest, NDArrayOrder1) { auto f = new float[4] {1, 3, 2, 4}; auto arrayC = new NDArray(c, cShape); - auto arrayF = arrayC->dup('f'); - auto arrayC2 = arrayF->dup('c'); + auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC2 = new NDArray(arrayF->dup('c')); ASSERT_EQ('c', arrayC->ordering()); ASSERT_EQ('f', arrayF->ordering()); @@ -128,7 +128,7 @@ TEST_F(NDArrayTest, TestGetScalar1) { ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); ASSERT_NEAR(4.0f, 
arrayC->e(1, 1), 1e-5f); - auto arrayF = arrayC->dup('f'); + auto arrayF = new NDArray(arrayC->dup('f')); ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); @@ -199,14 +199,12 @@ TEST_F(NDArrayTest, TestTad1) { auto row2 = array->tensorAlongDimension(1, {1}); - ASSERT_TRUE(row2->isView()); - ASSERT_EQ(3, row2->lengthOf()); + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); - row2->assign(1.0); + row2.assign(1.0); ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); - - delete row2; delete array; } @@ -225,18 +223,8 @@ TEST_F(NDArrayTest, TestTad3) { auto row2 = array->tensorAlongDimension(1, {1}); - ASSERT_TRUE(row2->isView()); - ASSERT_EQ(3, row2->lengthOf()); - - row2->p(1, 1.0); - - //array->printBuffer(); - - row2->p(2, 1.0); - - //array->printBuffer(); - - delete row2; + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); delete array; } @@ -296,17 +284,14 @@ TEST_F(NDArrayTest, TestRepeat1) { auto rep = array.repeat(0, {2}); - ASSERT_EQ(4, rep->sizeAt(0)); - ASSERT_EQ(2, rep->sizeAt(1)); - - // rep->printIndexedBuffer("Repeated"); + ASSERT_EQ(4, rep.sizeAt(0)); + ASSERT_EQ(2, rep.sizeAt(1)); ASSERT_TRUE(exp->equalsTo(rep)); delete[] eBuffer; delete[] eShape; delete exp; - delete rep; } ////////////////////////////////////////////////////////////////////// @@ -320,7 +305,7 @@ TEST_F(NDArrayTest, TestRepeat2) { //array->printBuffer(); - auto rep = exp->dup(); + auto rep = new NDArray(exp->dup()); rep->assign(0.); array->repeat(0, {2}, *rep); //rep->printIndexedBuffer("Repeated"); @@ -374,7 +359,7 @@ TEST_F(NDArrayTest, TestAddiRowVector) { auto exp = new NDArray(e, cShape); row->assign(1.0f); - array->addiRowVector(row); + array->addiRowVector(*row); ASSERT_TRUE(exp->equalsTo(array)); @@ -397,8 +382,8 @@ TEST_F(NDArrayTest, TestAddiColumnVector) { NDArray column(arr2, shape2); NDArray exp(arr3, shape1); - matrix.addiColumnVector(&column); - ASSERT_TRUE(exp.isSameShapeStrict(&matrix)); + 
matrix.addiColumnVector(column); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); ASSERT_TRUE(exp.equalsTo(&matrix)); } @@ -414,9 +399,9 @@ TEST_F(NDArrayTest, TestMuliColumnVector) { NDArray column(arr2, shape2); NDArray exp(arr3, shape1); - matrix.muliColumnVector(&column); + matrix.muliColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(&matrix)); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); ASSERT_TRUE(exp.equalsTo(&matrix)); } @@ -478,7 +463,7 @@ TEST_F(NDArrayTest, TestSumAlongDimension1) { NDArray array('c', {2,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); - auto res = array.reduceAlongDims(reduce::Sum, {0}); + auto res = array.reduceAlongDimension(reduce::Sum, {0}); ASSERT_EQ(2, res.lengthOf()); @@ -493,14 +478,13 @@ TEST_F(NDArrayTest, TestSumAlongDimension2) { auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res->lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res->e(0)); - ASSERT_EQ(7.0f, res->e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); delete[] c; delete array; - delete res; } ////////////////////////////////////////////////////////////////////// @@ -508,18 +492,15 @@ TEST_F(NDArrayTest, TestReduceAlongDimension1) { float *c = new float[4] {1, 2, 3, 4}; auto array = new NDArray(c, cShape); - auto exp = array->reduceAlongDimension(reduce::Sum, {1}); auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res->lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res->e(0)); - ASSERT_EQ(7.0f, res->e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); delete[] c; delete array; - delete exp; - delete res; } ////////////////////////////////////////////////////////////////////// @@ -530,7 +511,7 @@ TEST_F(NDArrayTest, TestTransform1) { float *e = new float[4] {1, 2, 3, 4}; auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, nullptr, nullptr); + array->applyTransform(transform::Abs, *array); ASSERT_TRUE(exp->equalsTo(array)); @@ -579,7 +560,7 
@@ TEST_F(NDArrayTest, TestApplyTransform1) { float *e = new float[4] {1, 2, 3, 4}; auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, nullptr, nullptr); + array->applyTransform(transform::Abs, *array); ASSERT_TRUE(exp->equalsTo(array)); @@ -668,20 +649,17 @@ TEST_F(NDArrayTest, TestReductionAny1) { array.syncToDevice(); auto result0 = array.reduceAlongDimension(reduce::Any, {0}); - ASSERT_EQ(2, result0->lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); - ASSERT_NEAR(1.0f, result0->e(0), 1e-5f); - ASSERT_NEAR(1.0f, result0->e(1), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); auto result1 = array.reduceAlongDimension(reduce::Any, {1}); - ASSERT_EQ(2, result1->lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_NEAR(1.0f, result1->e(0), 1e-5f); - ASSERT_NEAR(0.0f, result1->e(1), 1e-5f); - - delete result0; - delete result1; + ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); + ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); } TEST_F(NDArrayTest, TestReductionAll1) { @@ -694,17 +672,14 @@ TEST_F(NDArrayTest, TestReductionAll1) { auto result0 = array.reduceAlongDimension(reduce::All, {0}); auto result1 = array.reduceAlongDimension(reduce::All, {1}); - ASSERT_EQ(2, result0->lengthOf()); - ASSERT_EQ(2, result1->lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_FALSE(result0->e(0)); - ASSERT_FALSE(result0->e(1)); + ASSERT_FALSE(result0.e(0)); + ASSERT_FALSE(result0.e(1)); - ASSERT_TRUE(result1->e(0)); - ASSERT_FALSE(result1->e(1)); - - delete result0; - delete result1; + ASSERT_TRUE(result1.e(0)); + ASSERT_FALSE(result1.e(1)); } ////////////////////////////////////////////////////////////////////// @@ -728,7 +703,7 @@ TEST_F(NDArrayTest, TestTile1) { NDArray array1(arr1,shape1); // {2,3} NDArray array2(arr2,shape2); // {2,4,6} - auto expA = array1.dup('c'); + auto expA = new NDArray(array1.dup('c')); auto tiled = array1.tile(tileShape1); @@ -766,7 +741,7 @@ 
TEST_F(NDArrayTest, TestTile3) { array1.tilei(tileShape1); - ASSERT_TRUE(array1.isSameShapeStrict(&array2)); + ASSERT_TRUE(array1.isSameShapeStrict(array2)); ASSERT_TRUE(array1.equalsTo(&array2)); } @@ -781,7 +756,7 @@ TEST_F(NDArrayTest, TestTile4) { auto result = x.tile({2,1}); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -796,7 +771,7 @@ TEST_F(NDArrayTest, TestTile5) { auto result = x.tile({2,1}); - ASSERT_TRUE(result.isSameShapeStrict(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); ASSERT_TRUE(result.equalsTo(&exp)); } @@ -881,8 +856,8 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { for (int e = 0; e < y.lengthOf(); e++) y.p(e, e+1); - auto x_ = x.dup('f'); - auto y_ = y.dup('f'); + auto x_ = new NDArray(x.dup('f')); + auto y_ = new NDArray(y.dup('f')); x_->permutei({1, 0}); y_->permutei({1, 0}); @@ -940,7 +915,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { for (int e = 0; e < y.lengthOf(); e++) y.p(e, e+1); - auto y_ = y.dup('f'); + auto y_ = new NDArray(y.dup('f')); x.permutei({0, 3, 4, 5, 1, 2}); y_->permutei({3, 2, 1, 0}); @@ -1264,7 +1239,7 @@ TEST_F(NDArrayTest, Permute1) { NDArray arr2(shape2,true); auto result = arr1.permute(perm); - ASSERT_TRUE(result.isSameShapeStrict(&arr2)); + ASSERT_TRUE(result.isSameShapeStrict(arr2)); } ////////////////////////////////////////////////////////////////////// @@ -1279,33 +1254,16 @@ TEST_F(NDArrayTest, Permute2) { NDArray arr2(shape2,true); ASSERT_TRUE(arr1.permutei(perm)); - ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); } -////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Broadcast1) { - - Nd4jLong shape1[10] = {3, 5, 1, 10, 10, 10, 1, 8192, 1, 99}; - Nd4jLong shape2[8] = {2, 7, 10, 10, 1, 8192, 1, 99}; - Nd4jLong shape3[10] = {3, 5, 7, 10, 70, 10, 1, 8192, 1, 99}; - - NDArray arr1(shape1); - NDArray arr2(shape2); - NDArray 
arr3(shape3); - - auto result = arr1.broadcast(arr2); - ASSERT_TRUE(result->isSameShapeStrict(&arr3)); - delete result; -} - - TEST_F(NDArrayTest, RSubScalarTest1) { auto array = NDArrayFactory::create('c', {1, 4}); array.assign(2.0); auto result = NDArrayFactory::create('c', {1, 4}); - array.applyScalar(scalar::ReverseSubtract, 1.0, &result); + array.applyScalar(scalar::ReverseSubtract, 1.0, result); ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); } @@ -1324,7 +1282,7 @@ TEST_F(NDArrayTest, BroadcastOpsTest1) { ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, row); + x.applyBroadcast(broadcast::Add, {1}, *row, x); //x.printBuffer("Result"); @@ -1374,9 +1332,9 @@ TEST_F(NDArrayTest, TestIndexedPut5) { TEST_F(NDArrayTest, TestAllTensors1) { auto matrix = NDArrayFactory::create('c', {3, 5}); - std::unique_ptr rows(matrix.allTensorsAlongDimension({1})); + ResultSet rows = matrix.allTensorsAlongDimension({1}); - ASSERT_EQ(3, rows->size()); + ASSERT_EQ(3, rows.size()); } @@ -1573,17 +1531,15 @@ TEST_F(NDArrayTest, TestStdDev2) { auto array = NDArrayFactory::create('c', {5, 6}); auto tad = array.tensorAlongDimension(0, {0}); - ASSERT_EQ(5, tad->lengthOf()); + ASSERT_EQ(5, tad.lengthOf()); - for (int e = 0; e < tad->lengthOf(); e++) - tad->p(e, e+1); + for (int e = 0; e < tad.lengthOf(); e++) + tad.p(e, e+1); - ASSERT_NEAR(15, tad->sumNumber().e(0), 1e-5); + ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); - auto std = tad->varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(std, 1.58109, 1e-4); - - delete tad; } TEST_F(NDArrayTest, TestStdDev3) { @@ -1654,8 +1610,6 @@ TEST_F(NDArrayTest, TestApplyIndexReduce1) { auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } 
////////////////////////////////////////////////////////////////////// @@ -1667,11 +1621,9 @@ TEST_F(NDArrayTest, applyReduce3Dot) { NDArray x(xBuff, xShapeInfo); NDArray y(yBuff, xShapeInfo); - auto result = x.applyReduce3(reduce3::Dot, &y); - ASSERT_TRUE(result->lengthOf() == 1); - ASSERT_NEAR(42, result->e(0), 1e-5); - - delete result; + auto result = x.applyReduce3(reduce3::Dot, y); + ASSERT_TRUE(result.lengthOf() == 1); + ASSERT_NEAR(42, result.e(0), 1e-5); } ////////////////////////////////////////////////////////////////////// @@ -1686,17 +1638,12 @@ TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { NDArray y(yBuff, xShapeInfo); auto exp = NDArrayFactory::create('c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, &y,{1}); - - // result->printIndexedBuffer("result"); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; @@ -1709,12 +1656,10 @@ TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { NDArray y(yBuff, xShapeInfo); NDArray exp(expBuff, expShapeInfo); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, &y,{1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y ,{1}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } @@ -1733,8 +1678,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension1) { ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// @@ -1751,8 +1694,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension2) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); 
ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension3) { @@ -1765,9 +1706,8 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension3) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } + ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension4) { @@ -1779,8 +1719,6 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension4) { auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); - - delete result; } ////////////////////////////////////////////////////////////////////// @@ -1796,9 +1734,9 @@ TEST_F(NDArrayTest, TestSubRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.subRowVector(&y,&target); + x.subRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1815,9 +1753,9 @@ TEST_F(NDArrayTest, TestDivRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.divRowVector(&y,&target); + x.divRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1834,9 +1772,9 @@ TEST_F(NDArrayTest, TestMulRowVector1) { NDArray target(x); NDArray exp(expBuff, xShapeInfo); - x.mulRowVector(&y,&target); + x.mulRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); ASSERT_TRUE(exp.equalsTo(&target)); } @@ -1895,7 +1833,7 @@ TEST_F(NDArrayTest, TestBroadcast_1) { bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, &bias); + 
input.applyBroadcast(broadcast::Add, {1}, bias, input); //input.printBuffer("result"); ASSERT_TRUE(exp.equalsTo(&input)); @@ -2457,7 +2395,7 @@ TEST_F(NDArrayTest, Test_Lambda_1) { return _val + 3.0f; }; - x.applyLambda(lambda); + x.applyLambda(lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2472,7 +2410,7 @@ TEST_F(NDArrayTest, Test_Lambda_2) { return _x + _y + 1.0f; }; - x.applyPairwiseLambda(&y, lambda); + x.applyPairwiseLambda(y, lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2487,7 +2425,7 @@ TEST_F(NDArrayTest, Test_Lambda_3) { return (_x + _y) * 2; }; - x.applyPairwiseLambda(&y, lambda); + x.applyPairwiseLambda(y, lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -2518,8 +2456,6 @@ TEST_F(NDArrayTest, Test_diagonal_1) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2533,8 +2469,6 @@ TEST_F(NDArrayTest, Test_diagonal_2) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2548,8 +2482,6 @@ TEST_F(NDArrayTest, Test_diagonal_3) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2563,8 +2495,6 @@ TEST_F(NDArrayTest, Test_diagonal_4) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2578,8 +2508,6 @@ TEST_F(NDArrayTest, Test_diagonal_5) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2593,8 +2521,6 @@ TEST_F(NDArrayTest, Test_diagonal_6) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ 
-2608,8 +2534,6 @@ TEST_F(NDArrayTest, Test_diagonal_7) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2623,8 +2547,6 @@ TEST_F(NDArrayTest, Test_diagonal_8) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2638,8 +2560,6 @@ TEST_F(NDArrayTest, Test_diagonal_9) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } @@ -2654,8 +2574,6 @@ TEST_F(NDArrayTest, Test_diagonal_10) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2669,8 +2587,6 @@ TEST_F(NDArrayTest, Test_diagonal_11) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2684,8 +2600,6 @@ TEST_F(NDArrayTest, Test_diagonal_12) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } //////////////////////////////////////////////////////////////////// @@ -2699,8 +2613,6 @@ TEST_F(NDArrayTest, Test_diagonal_13) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } //////////////////////////////////////////////////////////////////// @@ -2714,8 +2626,6 @@ TEST_F(NDArrayTest, Test_diagonal_14) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2729,8 +2639,6 @@ TEST_F(NDArrayTest, Test_diagonal_15) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2744,8 +2652,6 @@ TEST_F(NDArrayTest, Test_diagonal_16) { 
ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2759,8 +2665,6 @@ TEST_F(NDArrayTest, Test_diagonal_17) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// @@ -2774,8 +2678,6 @@ TEST_F(NDArrayTest, Test_diagonal_18) { ASSERT_TRUE(exp.isSameShape(diag)); ASSERT_TRUE(exp.equalsTo(diag)); - - delete diag; } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index 4f8d38e76..4507086f5 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -196,12 +196,10 @@ TEST_F(NDArrayTest2, Test_AllReduce3_1) { auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } //////////////////////////////////////////////////////////////////// @@ -210,12 +208,10 @@ TEST_F(NDArrayTest2, Test_AllReduce3_2) { auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } //////////////////////////////////////////////////////////////////// @@ -278,7 +274,7 @@ TEST_F(NDArrayTest2, Test_Streamline_1) { ASSERT_TRUE(x.isSameShape(&y)); ASSERT_TRUE(x.equalsTo(&y)); - 
ASSERT_FALSE(x.isSameShapeStrict(&y)); + ASSERT_FALSE(x.isSameShapeStrict(y)); } @@ -306,7 +302,7 @@ TEST_F(NDArrayTest2, Test_Enforce_1) { x.enforce({4, 4}, 'c'); - ASSERT_TRUE(exp.isSameShapeStrict(&x)); + ASSERT_TRUE(exp.isSameShapeStrict(x)); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -315,7 +311,7 @@ TEST_F(NDArrayTest2, TestVector_1) { auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - x.addiRowVector(&row); + x.addiRowVector(row); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -359,7 +355,7 @@ TEST_F(NDArrayTest2, tileToShape_test1) { auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - x.tileToShape({2,2,2}); + x.tileToShape({2,2,2}, x); ASSERT_TRUE(x.isSameShape(&exp)); ASSERT_TRUE(x.equalsTo(&exp)); @@ -371,7 +367,7 @@ TEST_F(NDArrayTest2, tileToShape_test2) { auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); - x.tileToShape({2,3,2}); + x.tileToShape({2,3,2}, x); ASSERT_TRUE(x.isSameShape(&exp)); ASSERT_TRUE(x.equalsTo(&exp)); @@ -384,7 +380,7 @@ TEST_F(NDArrayTest2, tileToShape_test3) { auto result = NDArrayFactory::create('c', {2, 2, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - x.tileToShape({2,2,2}, &result); + x.tileToShape({2,2,2}, result); // result.printIndexedBuffer(); ASSERT_TRUE(result.isSameShape(&exp)); @@ -398,7 +394,7 @@ TEST_F(NDArrayTest2, tileToShape_test4) { auto result = NDArrayFactory::create('c', {2, 3, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); - x.tileToShape({2,3,2}, &result); + x.tileToShape({2,3,2}, result); ASSERT_TRUE(result.isSameShape(&exp)); ASSERT_TRUE(result.equalsTo(&exp)); @@ -418,7 +414,7 @@ TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(&u, &v, la); + 
t.applyTriplewiseLambda(u, v, la, t); ASSERT_TRUE(t.equalsTo(&exp)); } @@ -436,7 +432,7 @@ TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(&u, &v, la); + t.applyTriplewiseLambda(u, v, la, t); ASSERT_TRUE(t.equalsTo(&exp)); } @@ -450,7 +446,7 @@ TEST_F(NDArrayTest2, Test_Indexed_Lambda) { return (float) _idx; }; - x.applyIndexedLambda(lambda); + x.applyIndexedLambda(lambda, x); ASSERT_TRUE(exp.equalsTo(&x)); } @@ -565,7 +561,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test1) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); - x.fillAsTriangular(0., 0, 0, 'u'); + x.fillAsTriangular(0., 0, 0, x, 'u'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -578,7 +574,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test2) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); - x.fillAsTriangular(0., 0, -1, 'u'); + x.fillAsTriangular(0., 0, -1, x, 'u'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -591,7 +587,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test3) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); - x.fillAsTriangular(0., 0, 0, 'l'); + x.fillAsTriangular(0., 0, 0, x, 'l'); ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.equalsTo(&x)); @@ -604,7 +600,7 @@ TEST_F(NDArrayTest2, fillAsTriangular_test4) { auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); - x.fillAsTriangular(0., 1, 0, 'l'); + x.fillAsTriangular(0., 1, 0, x, 'l'); ASSERT_TRUE(exp.isSameShape(&x)); 
ASSERT_TRUE(exp.equalsTo(&x)); @@ -616,13 +612,10 @@ TEST_F(NDArrayTest2, Test_DType_Conversion_1) { auto xd = x.template asT(); - auto xf = xd->template asT(); + auto xf = xd.template asT(); ASSERT_TRUE(x.isSameShape(xf)); ASSERT_TRUE(x.equalsTo(xf)); - - delete xf; - delete xd; } //////////////////////////////////////////////////////////////////// @@ -677,7 +670,7 @@ TEST_F(NDArrayTest2, permute_test4) { // arr1P->printShapeInfo(); // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); - ASSERT_TRUE(arr1P.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); delete []arr1Buffer; delete []arr2Buffer; } @@ -773,11 +766,9 @@ TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { // set->at(0)->printShapeInfo(); // set->at(0)->printIndexedBuffer(); - ASSERT_TRUE(set->size() == 1); - ASSERT_TRUE(exp.isSameShape(set->at(0))); - ASSERT_TRUE(exp.equalsTo(set->at(0))); - - delete set; + ASSERT_TRUE(set.size() == 1); + ASSERT_TRUE(exp.isSameShape(set.at(0))); + ASSERT_TRUE(exp.equalsTo(set.at(0))); } //////////////////////////////////////////////////////////////////// @@ -838,7 +829,7 @@ TEST_F(NDArrayTest2, scalar_set_test2) { TEST_F(NDArrayTest2, big_dup_test) { // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); - auto dup = arr->dup('c'); + auto dup = new NDArray(arr->dup('c')); ASSERT_EQ(*arr, *dup); @@ -920,8 +911,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_1) { NDArray x('c', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - ASSERT_EQ(5, subArr1->ews()); - delete subArr1; + ASSERT_EQ(5, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -930,8 +920,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_2) { NDArray x('f', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - ASSERT_EQ(1, subArr1->ews()); - delete subArr1; + ASSERT_EQ(1, subArr1.ews()); } 
////////////////////////////////////////////////////////////////////// @@ -940,8 +929,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_3) { NDArray x('c', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - ASSERT_EQ(1, subArr1->ews()); - delete subArr1; + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -950,8 +938,7 @@ TEST_F(NDArrayTest2, test_subarray_ews_4) { NDArray x('f', {10, 5}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - ASSERT_EQ(10, subArr1->ews()); - delete subArr1; + ASSERT_EQ(10, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// @@ -1065,9 +1052,8 @@ TEST_F(NDArrayTest2, test_subarray_interval_1) { NDArray x('f', {10, 10}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - ASSERT_EQ(10, subArr1->sizeAt(0)); - ASSERT_EQ(9, subArr1->sizeAt(1)); - delete subArr1; + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_interval_2) { @@ -1075,9 +1061,8 @@ TEST_F(NDArrayTest2, test_subarray_interval_2) { NDArray x('c', {10, 10}, nd4j::DataType::FLOAT32); auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - ASSERT_EQ(10, subArr1->sizeAt(0)); - ASSERT_EQ(9, subArr1->sizeAt(1)); - delete subArr1; + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_3d_cf) { @@ -1117,7 +1102,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_2) { auto e = NDArrayFactory::create('c', {5, 10}); e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x, false); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); ASSERT_EQ(e, x); } @@ -1128,7 +1113,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_3) { auto e = NDArrayFactory::create('c', {5, 10}); e.assign(1.0f); - 
x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); ASSERT_EQ(e, x); } @@ -1139,7 +1124,7 @@ TEST_F(NDArrayTest2, test_broadcast_column_4) { auto e = NDArrayFactory::create('f', {10, 5}); e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); ASSERT_EQ(e, x); } @@ -1171,7 +1156,7 @@ TEST_F(NDArrayTest2, test_not_tiled_2) { TEST_F(NDArrayTest2, test_long_sum_1) { auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto z = x.reduceAlongDims(reduce::Sum, {0}); + auto z = x.reduceAlongDimension(reduce::Sum, {0}); } ////////////////////////////////////////////////////////////////////// @@ -1216,7 +1201,7 @@ TEST_F(NDArrayTest2, trueBroadcast_1) { NDArray z('c', {2, 3}, nd4j::DataType::DOUBLE); auto exp = x - y; - x.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), &y, &z, true); + x.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Subtract(), y, z); // exp.printIndexedBuffer(); // z.printIndexedBuffer(); @@ -1232,7 +1217,7 @@ TEST_F(NDArrayTest2, reduce_1) { arr6.linspace(1); - NDArray* arr6s = arr6.reduceAlongDimension(nd4j::reduce::Sum, {2,3}); + NDArray arr6s = arr6.reduceAlongDimension(nd4j::reduce::Sum, {2,3}); for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { @@ -1254,8 +1239,6 @@ TEST_F(NDArrayTest2, reduce_1) { // arr6s->printIndexedBuffer(); ASSERT_TRUE(exp.equalsTo(arr6s)); - - delete arr6s; } ////////////////////////////////////////////////////////////////////// @@ -1265,23 +1248,17 @@ TEST_F(NDArrayTest2, reduce3_1) { NDArray y('c', {1,4}, {2,3,4,5}); NDArray exp('c', {4}, {1,1,1,1}); - NDArray* z = x.applyReduce3(nd4j::reduce3::EuclideanDistance, &y, {0}, nullptr); - // z->printShapeInfo(); - // z->printIndexedBuffer(); + NDArray z = x.applyReduce3(nd4j::reduce3::EuclideanDistance, y, {0}, nullptr); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - delete z; } TEST_F(NDArrayTest2, 
all_tads_1) { auto x = NDArrayFactory::create('c', {3, 5}); auto arrays = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, arrays->size()); - - delete arrays; + ASSERT_EQ(3, arrays.size()); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 6d58e6e41..edef394f3 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -92,9 +92,9 @@ TEST_F(ParityOpsTests, TestMinimum1) { TEST_F(ParityOpsTests, TestTear1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(5, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::tear op; @@ -104,18 +104,17 @@ TEST_F(ParityOpsTests, TestTear1) { ASSERT_EQ(10, result->size()); for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } TEST_F(ParityOpsTests, TestUnstack1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(5, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::unstack op; @@ -124,14 +123,10 @@ TEST_F(ParityOpsTests, TestUnstack1) { ASSERT_EQ(10, result->size()); - // result->at(0)->printShapeInfo("rz"); - // tads->at(0)->printShapeInfo("re"); - for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } @@ -139,9 
+134,9 @@ TEST_F(ParityOpsTests, TestUnstack1) { TEST_F(ParityOpsTests, TestUnstack2) { auto input = NDArrayFactory::create('c', {5,2,6}); auto tads = input.allTensorsAlongDimension({0,1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_EQ(10, tads->at(e)->lengthOf()); - tads->at(e)->assign((float) e + 1); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(10, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); } nd4j::ops::unstack op; @@ -151,10 +146,9 @@ TEST_F(ParityOpsTests, TestUnstack2) { ASSERT_EQ(6, result->size()); for (int e = 0; e < result->size(); e++) - ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); + ASSERT_TRUE(tads.at(e)->equalsTo(result->at(e))); delete result; - delete tads; } TEST_F(ParityOpsTests, TestUnstack3) { @@ -689,11 +683,10 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { auto z = result->at(0); auto tads = z->allTensorsAlongDimension({1}); - for (int e = 0; e < tads->size(); e++) { - ASSERT_TRUE(bias.equalsTo(tads->at(e))); + for (int e = 0; e < tads.size(); e++) { + ASSERT_TRUE(bias.equalsTo(tads.at(e))); } - delete tads; delete result; } @@ -833,7 +826,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_8) { // z.printBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(&z)); + ASSERT_TRUE(expected.isSameShapeStrict(z)); ASSERT_TRUE(expected.equalsTo(z)); } diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 122f25273..d35d736ed 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -376,4 +376,13 @@ TEST_F(PlaygroundTests, my) { delete variableSpace; } -*/ \ No newline at end of file +*/ + +TEST_F(PlaygroundTests, my) { + + NDArray a('c',{2,3,4}, nd4j::DataType::DOUBLE); + a({0,0, 0,1, 0,1}).printShapeInfo(); + a({0,1, 0,0, 0,1}).printShapeInfo(); + a({0,0, 0,1, 0,1}).printShapeInfo(); + +} \ No newline at end of file diff --git 
a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 5c3ca340b..1072b9dab 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -275,8 +275,8 @@ TEST_F(RNGTests, Test_Gaussian_21) { #ifdef DEBUG_BUILD TEST_F(RNGTests, Test_Gaussian_22) { - auto x0 = NDArrayFactory::create('c', {10000, 1000}); - auto x1 = NDArrayFactory::create('c', {10000, 1000}); + auto x0 = NDArrayFactory::create('c', {1000, 800}); + auto x1 = NDArrayFactory::create('c', {1000, 800}); RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); @@ -304,7 +304,7 @@ TEST_F(RNGTests, Test_Gaussian_22) { } TEST_F(RNGTests, Test_Gaussian_3) { - auto x0 = NDArrayFactory::create('c', {10000000}); + auto x0 = NDArrayFactory::create('c', {800000}); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); @@ -381,8 +381,8 @@ TEST_F(RNGTests, Test_Truncated_2) { } TEST_F(RNGTests, Test_Truncated_21) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); @@ -428,8 +428,8 @@ TEST_F(RNGTests, Test_Truncated_21) { } TEST_F(RNGTests, Test_Truncated_22) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); 
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); @@ -522,27 +522,20 @@ TEST_F(RNGTests, Test_Truncated_23) { } TEST_F(RNGTests, Test_Truncated_3) { - auto x0 = NDArrayFactory::create('c', {10000, 1000}); - auto x1 = NDArrayFactory::create('c', {10000, 1000}); + auto x0 = NDArrayFactory::create('c', {2000, 2000}); + auto x1 = NDArrayFactory::create('c', {2000, 2000}); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); ASSERT_TRUE(x0.equalsTo(&x1)); - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - // Check up distribution auto mean = x1.reduceNumber(reduce::Mean); // mean.printIndexedBuffer("Mean 1.0"); //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); ASSERT_NEAR(mean.e(0), 1.f, 0.001); ASSERT_NEAR(deviation.e(0), 2.f, 0.3); } diff --git a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp index 3cdca2db6..404b95013 100644 --- a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp @@ -35,17 +35,15 @@ TEST_F(ResultSetTests, basic_test_1) { auto x = NDArrayFactory::create('c', {3, 5}); auto tensors = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, tensors->size()); + ASSERT_EQ(3, tensors.size()); - ResultSet set = *tensors; - ASSERT_EQ(3, tensors->size()); + ResultSet set = tensors; + ASSERT_EQ(3, tensors.size()); ASSERT_EQ(3, set.size()); for (int e = 0; e < set.size(); e++) ASSERT_EQ(5, set.at(e)->lengthOf()); - for (int e = 0; e < tensors->size(); e++) - ASSERT_EQ(5, 
tensors->at(e)->lengthOf()); - - delete tensors; + for (int e = 0; e < tensors.size(); e++) + ASSERT_EQ(5, tensors.at(e)->lengthOf()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp index 3762d790c..41f8ed2d0 100644 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp @@ -74,7 +74,7 @@ TEST_F(SessionLocalTests, BasicTests_2) { auto varSpace = storage.localVariableSpace(); auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(nd4j::scalar::Add, (float) e+1); + arr->applyScalar(nd4j::scalar::Add, (float) e+1, *arr); } float lastValue = 0.0f; diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index 2ae236210..9f9569b92 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -81,7 +81,7 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(1, array.lengthOf()); ASSERT_EQ(0, array.rankOf()); - auto dup = array.dup(); + auto dup = new NDArray(array.dup()); auto z0 = array.e(0); auto z1 = dup->e(0); diff --git a/libnd4j/tests_cpu/layers_tests/TadTests.cpp b/libnd4j/tests_cpu/layers_tests/TadTests.cpp index b4a631a8c..86e7264e8 100644 --- a/libnd4j/tests_cpu/layers_tests/TadTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TadTests.cpp @@ -106,7 +106,7 @@ TEST_F(TadTests, TestShapeTad_1) { NDArray tadArr(tadBuff, tadShapeInfo); ASSERT_TRUE(numTads==1); - ASSERT_TRUE(input.isSameShapeStrict(&tadArr)); + ASSERT_TRUE(input.isSameShapeStrict(tadArr)); ASSERT_TRUE(input.equalsTo(&tadArr)); delete[] tadShapeInfo; @@ -133,24 +133,16 @@ TEST_F(TadTests, TadEdgeCase_1) { auto tad = array.tensorAlongDimension(0, {0, 1}); ASSERT_TRUE(exp.isSameShape(tad)); - - delete tad; } TEST_F(TadTests, TestEdgeCase_2) { - auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - auto tad1 
= array.tensorAlongDimension(1, {2}); + auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); for (int e = 0 ; e < array.lengthOf(); e++) { auto tad = array.tensorAlongDimension(e, {2}); - - ASSERT_NEAR(tad->e(0), array.e(e), 1e-5); - - delete tad; + ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); } - - delete tad1; } TEST_F(TadTests, TadEdgeCase_2) { @@ -158,10 +150,7 @@ TEST_F(TadTests, TadEdgeCase_2) { auto tad = array.tensorAlongDimension(0, {1}); - // tad->printShapeInfo("TAD shape"); - ASSERT_EQ(3, tad->lengthOf()); - - delete tad; + ASSERT_EQ(3, tad.lengthOf()); } diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index fcdd1db3c..1a1915fdc 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -166,7 +166,6 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { ASSERT_TRUE(floating.equalsTo(conv)); delete rv; - delete conv; } /*