From 0cf4a45573d79d846cb2053664c981223a10ed80 Mon Sep 17 00:00:00 2001 From: Adam Gibson <1144306+agibsonccc@users.noreply.github.com> Date: Thu, 19 Mar 2020 14:53:21 +0900 Subject: [PATCH 01/17] Fixes #8763 (#310) * Fix cmake detection in msys * Fix toolchain file on windows * Make android 64 bit work * Fix libnd4j build script on msys * Update build script for windows/linux * Encoding issue for ci * Update pom.xml * Update pom.xml * Update pom.xml * Remove mingw * Ensure android x86 builds are inline with arm builds * Update toolchains and env variables for x86 * Move profile for build program up to parent * Fix blas vendor and add comment * Update cuda presets version * Set default value and move properties back to child pom * Change program from hard coded to use the script as the program * Update pom.xml * Update pom.xml * Static lib fix * Update static lib output * Get rid of old comments * Update static for buiding --- libnd4j/CMakeLists.txt | 4 +- libnd4j/blas/CMakeLists.txt | 16 +++- libnd4j/buildnativeoperations.sh | 94 +++++++++++++++---- libnd4j/cmake/android-arm.cmake | 33 +++---- libnd4j/cmake/android-arm64.cmake | 30 +++--- libnd4j/cmake/android-x86.cmake | 30 +++--- libnd4j/cmake/android-x86_64.cmake | 33 +++---- libnd4j/pom.xml | 48 ++++++++-- nd4j/compile-android.sh | 1 + .../nd4j-api-parent/nd4j-native-api/pom.xml | 5 - .../nd4j-backend-impls/nd4j-cuda/pom.xml | 2 +- .../nd4j-backend-impls/nd4j-native/pom.xml | 33 +------ nd4j/nd4j-backends/nd4j-backend-impls/pom.xml | 4 +- pom.xml | 1 + 14 files changed, 193 insertions(+), 141 deletions(-) create mode 100644 nd4j/compile-android.sh diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 712d123be..63a83c05b 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -79,10 +79,11 @@ if(NOT SD_CUDA) if ("${OPENBLAS_PATH}" STREQUAL "") #we don't want OpenBLAS on Apple if (NOT APPLE) + # note: this is not a typo set(BLA_VENDOR "OpenBLAS") endif() - # look around for system blas 
instead + # look around for system blas instead, see: https://cmake.org/cmake/help/latest/module/FindBLAS.html find_package(BLAS REQUIRED) if (BLAS_FOUND) message("Found external BLAS implementation: ${BLAS_LIBRARIES} ") @@ -91,6 +92,7 @@ if(NOT SD_CUDA) else() # if we have externally provided OPENBLAS_PATH - let's use it set(HAVE_OPENBLAS 1) + message("Setting openblas") include_directories(${OPENBLAS_PATH}/include/) link_directories(${OPENBLAS_PATH} ${OPENBLAS_PATH}/lib/) set(OPENBLAS_LIBRARIES openblas) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 2dccc680f..a793063bc 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -49,7 +49,7 @@ if (SD_IOS_BUILD) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DSD_IOS_BUILD=true") endif() -if(WIN32) +if(WIN32 AND NOT ANDROID) get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj") @@ -231,7 +231,11 @@ if(SD_CUDA) ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}) - add_library(${SD_LIBRARY_NAME} SHARED $) + # Don't output dynamic linked lib when a static lib build is specified unless the tests are built + if(NOT SD_STATIC_LIB OR SD_BUILD_TESTS) + add_library(${SD_LIBRARY_NAME} SHARED $) + endif() + if (WIN32) message("MSVC runtime for library: ${MSVC_RT_LIB}") @@ -241,7 +245,7 @@ if(SD_CUDA) if (SD_BUILD_TESTS OR SD_STATIC_LIB) add_library(${SD_LIBRARY_NAME}static STATIC $) set_property(TARGET ${SD_LIBRARY_NAME}static PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") - install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .) + install(TARGETS ${SD_LIBRARY_NAME}static DESTINATION .) 
endif() # on windows we want to make sure we use MT or MD, but since we use it in one lib, we must use it everywhere to avoid conflicts @@ -320,14 +324,16 @@ elseif(SD_CPU) ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) if(IOS) - add_library(${SD_LIBRARY_NAME} STATIC $) + add_library(${SD_LIBRARY_NAME} STATIC $) else() # static library is built only if we're going to build tests, skip otherwise if (SD_BUILD_TESTS OR SD_STATIC_LIB) add_library(${SD_LIBRARY_NAME}static STATIC $) endif() - add_library(${SD_LIBRARY_NAME} SHARED $) + if(SD_BUILD_TESTS OR NOT SD_STATIC_LIB) + add_library(${SD_LIBRARY_NAME} SHARED $) + endif() endif() # we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index 380238554..9a2c3e240 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -21,6 +21,33 @@ set -eu DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" cd "$DIR" +setwindows_msys() { + if [[ $KERNEL == *"windows"* ]]; then + export CMAKE_COMMAND="$CMAKE_COMMAND -G \"MSYS Makefiles\"" + fi +} + +setandroid_defaults() { + if [[ -z ${ANDROID_NDK:-} ]]; then + export ANDROID_NDK=$HOME/Android/android-ndk/ + echo "No ANDROID_NDK variable set. Setting to default of $ANDROID_NDK" + else + echo "USING ANDROID NDK $ANDROID_NDK" +fi + + if [[ -z ${ANDROID_VERSION:-} ]]; then + export ANDROID_VERSION=21 + echo "No ANDROID_VERSION variable set. 
Setting to default of $ANDROID_VERSION" + else + echo "USING ANDROID VERSION $ANDROID_VERSION" + # android needs static linking + +fi + + +} + + export CMAKE_COMMAND="cmake" if which cmake3 &> /dev/null; then export CMAKE_COMMAND="cmake3" @@ -57,7 +84,7 @@ VERBOSE_ARG="VERBOSE=1" HELPER= CHECK_VECTORIZATION="OFF" NAME= -while [[ $# > 0 ]] +while [[ $# -gt 0 ]] do key="$1" value="${2:-}" @@ -141,7 +168,7 @@ case $key in # unknown option ;; esac -if [[ $# > 0 ]]; then +if [[ $# -gt 0 ]]; then shift # past argument or value fi done @@ -154,6 +181,8 @@ if [ "$(uname)" == "Darwin" ]; then elif [ "$(expr substr $(uname -s) 1 5)" == "MINGW" ] || [ "$(expr substr $(uname -s) 1 4)" == "MSYS" ]; then HOST="windows" KERNEL="windows-x86_64" + # need to set build path separator, it ends up being wrong on msys2 + BUILD_PATH_SEPARATOR=";" echo "Running windows" elif [ "$(uname -m)" == "ppc64le" ]; then if [ -z "$ARCH" ]; then @@ -166,9 +195,9 @@ if [ -z "$OS" ]; then OS="$HOST" fi -if [[ -z ${ANDROID_NDK:-} ]]; then - export ANDROID_NDK=$HOME/Android/android-ndk/ -fi + + +echo "RUNNING BUILD FOR OS $OS" case "$OS" in linux-armhf) @@ -190,44 +219,67 @@ case "$OS" in if [ -z "$ARCH" ]; then ARCH="armv7-a" fi + + setandroid_defaults + + export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-$ANDROID_VERSION/arch-arm/" export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm.cmake -DSD_ANDROID_BUILD=true" + setwindows_msys ;; android-arm64) if [ -z "$ARCH" ]; then ARCH="armv8-a" fi + + setandroid_defaults + + echo "BUILDING ANDROID ARM with KERNEL $KERNEL" export ANDROID_BIN="$ANDROID_NDK/toolchains/aarch64-linux-android-4.9/prebuilt/$KERNEL/" export 
ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm64/" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DSD_ANDROID_BUILD=true" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-$ANDROID_VERSION/arch-arm64/" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DSD_ANDROID_BUILD=true" + setwindows_msys ;; android-x86) if [ -z "$ARCH" ]; then ARCH="i686" fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/x86-4.9/prebuilt/$KERNEL/" + echo "BUILDING ANDROID x86" + + setandroid_defaults + + + export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-$ANDROID_VERSION/arch-x86/" export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86.cmake -DSD_ANDROID_BUILD=true" + setwindows_msys ;; android-x86_64) + if [ -z "$ARCH" ]; then ARCH="x86-64" fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/x86_64-4.9/prebuilt/$KERNEL/" + echo "BUILDING ANDROID x86_64" + + setandroid_defaults + + + export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86_64/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-$ANDROID_VERSION/arch-x86_64/" export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86_64.cmake -DSD_ANDROID_BUILD=true" + setwindows_msys ;; ios-x86_64) @@ -312,6 +364,7 @@ case "$OS" in 
PARALLEL="true" VERBOSE_ARG="-v" else + echo "SETTING UP WINDOWS" export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\"" export MAKE_COMMAND="make" export CC=/mingw64/bin/gcc @@ -400,9 +453,9 @@ if [ -z "$NAME" ]; then fi if [ "$LIBTYPE" == "dynamic" ]; then - SHARED_LIBS_ARG="-DSD_SHARED_LIB=OFF" + SHARED_LIBS_ARG="-DSD_SHARED_LIB=ON -DSD_STATIC_LIB=OFF" else - SHARED_LIBS_ARG="-DSD_SHARED_LIB=ON" + SHARED_LIBS_ARG="-DSD_SHARED_LIB=OFF -DSD_STATIC_LIB=ON" fi if [ "$BUILD" == "release" ]; then @@ -464,11 +517,14 @@ if [ "$CHIP" == "cuda" ] && [ -n "$CHIP_VERSION" ]; then esac fi + [[ -z ${OPENBLAS_PATH:-} ]] && OPENBLAS_PATH="" +OPENBLAS_PATH="${OPENBLAS_PATH//\\//}" if [[ -n "${BUILD_PATH:-}" ]]; then PREVIFS="$IFS" IFS="$BUILD_PATH_SEPARATOR" + echo "BUILD PATH BUILD_PATH_SEPARATOR IS $BUILD_PATH_SEPARATOR" for P in $BUILD_PATH; do if [[ -f "$P/include/openblas_config.h" ]]; then OPENBLAS_PATH="$P" @@ -485,6 +541,7 @@ fi # replace any backslash with a slash OPENBLAS_PATH="${OPENBLAS_PATH//\\//}" + mkbuilddir() { if [ "$CLEAN" == "true" ]; then echo "Removing blasbuild" @@ -537,7 +594,7 @@ echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION" echo HELPERS = "$HELPERS" mkbuilddir pwd -eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DSD_CHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. +eval "$CMAKE_COMMAND" "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DSD_CHECK_VECTORIZATION="${CHECK_VECTORIZATION}" "$HELPERS" "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. 
if [ "$PARALLEL" == "true" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" @@ -551,9 +608,10 @@ if [ "$CHECK_VECTORIZATION" == "ON" ]; then if [ "$MAKE_COMMAND" == "make" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS --output-sync=target" fi + exec 3>&1 -eval $MAKE_COMMAND $MAKE_ARGUMENTS 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../.. +eval "$MAKE_COMMAND" "$MAKE_ARGUMENTS" 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../.. exec 3>&- else -eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../.. -fi +eval "$MAKE_COMMAND" "$MAKE_ARGUMENTS" && cd ../../.. +fi \ No newline at end of file diff --git a/libnd4j/cmake/android-arm.cmake b/libnd4j/cmake/android-arm.cmake index 75a3903c7..9d150b070 100644 --- a/libnd4j/cmake/android-arm.cmake +++ b/libnd4j/cmake/android-arm.cmake @@ -1,27 +1,22 @@ # CMake toolchain to build for Android 5.0 or newer. Sample usage: # -# ANDROID_BIN="/path/to/android-ndk/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/" \ -# ANDROID_CPP="/path/to/android-ndk/sources/cxx-stl/llvm-libc++/" \ -# ANDROID_CC="/path/to/android-ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/clang" \ -# ANDROID_ROOT="/path/to/android-ndk/platforms/android-21/arch-arm/" \ -# cmake -DCMAKE_TOOLCHAIN_FILE=android-arm.cmake -DCMAKE_INSTALL_PREFIX=.. 
-# -# If you really need to use libnd4j on a CPU with no FPU, replace "libs/armeabi-v7a" by "libs/armeabi" and -# "-march=armv7-a -mfloat-abi=softfp -mfpu=vfpv3-d16" with "-march=armv5te -mtune=xscale -msoft-float" +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_ANDROID_ARCH_ABI arm64-v8a) +set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") +set(CMAKE_ANDROID_STL_TYPE c++_shared) +set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) -set(CMAKE_SYSTEM_NAME UnixPaths) -set(CMAKE_SYSTEM_PROCESSOR arm) set(ANDROID TRUE) +if (WIN32) + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") + else() + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") +endif (WIN32) -set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") -set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") -set(CMAKE_C_LINK_EXECUTABLE " -target armv7-none-linux-androideabi -Wl,--fix-cortex-a8 -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_LINK_EXECUTABLE " -target armv7-none-linux-androideabi -Wl,--fix-cortex-a8 -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/armeabi-v7a/ -nostdlib++ -lc++_static -lc++abi -landroid_support -lm -lc") -set(CMAKE_C_CREATE_SHARED_LIBRARY " -target armv7-none-linux-androideabi -Wl,--fix-cortex-a8 -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -target armv7-none-linux-androideabi -Wl,--fix-cortex-a8 -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/armeabi-v7a/ -nostdlib++ -lc++_static -lc++abi -landroid_support -lm -lc") +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android 
-march=armv8-a) -add_definitions(-D__ANDROID_API__=21 -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target armv7-none-linux-androideabi -march=armv7-a -mfloat-abi=softfp -mfpu=vfpv3-d16) - -include_directories("$ENV{ANDROID_CPP}/include/" "$ENV{ANDROID_CPP}/../llvm-libc++abi/include/" "$ENV{ANDROID_NDK}/sources/android/support/include/" "$ENV{ANDROID_CPP}/libs/armeabi-v7a/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/arm-linux-androideabi/" "$ENV{ANDROID_ROOT}/usr/include/") diff --git a/libnd4j/cmake/android-arm64.cmake b/libnd4j/cmake/android-arm64.cmake index abc649cb4..9d150b070 100644 --- a/libnd4j/cmake/android-arm64.cmake +++ b/libnd4j/cmake/android-arm64.cmake @@ -1,24 +1,22 @@ # CMake toolchain to build for Android 5.0 or newer. Sample usage: # -# ANDROID_BIN="/path/to/android-ndk/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/" \ -# ANDROID_CPP="/path/to/android-ndk/sources/cxx-stl/llvm-libc++/" \ -# ANDROID_CC="/path/to/android-ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/clang" \ -# ANDROID_ROOT="/path/to/android-ndk/platforms/android-21/arch-arm64/" \ -# cmake -DCMAKE_TOOLCHAIN_FILE=android-arm64.cmake -DCMAKE_INSTALL_PREFIX=.. 
+set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_ANDROID_ARCH_ABI arm64-v8a) +set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") +set(CMAKE_ANDROID_STL_TYPE c++_shared) +set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) -set(CMAKE_SYSTEM_NAME UnixPaths) -set(CMAKE_SYSTEM_PROCESSOR arm64) set(ANDROID TRUE) +if (WIN32) + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") + else() + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") +endif (WIN32) -set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") -set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") -set(CMAKE_C_LINK_EXECUTABLE " -target aarch64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_LINK_EXECUTABLE " -target aarch64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/arm64-v8a/ -nostdlib++ -lc++_static -lc++abi -lm -lc") -set(CMAKE_C_CREATE_SHARED_LIBRARY " -target aarch64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -target aarch64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/arm64-v8a/ -nostdlib++ -lc++_static -lc++abi -lm -lc") +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android -march=armv8-a) -add_definitions(-D__ANDROID_API__=21 -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android -march=armv8-a) - -include_directories("$ENV{ANDROID_CPP}/include/" "$ENV{ANDROID_CPP}/../llvm-libc++abi/include/" "$ENV{ANDROID_NDK}/sources/android/support/include/" 
"$ENV{ANDROID_CPP}/libs/arm64-v8a/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/aarch64-linux-android/" "$ENV{ANDROID_ROOT}/usr/include/") diff --git a/libnd4j/cmake/android-x86.cmake b/libnd4j/cmake/android-x86.cmake index 6065161aa..7c3297b74 100644 --- a/libnd4j/cmake/android-x86.cmake +++ b/libnd4j/cmake/android-x86.cmake @@ -1,24 +1,22 @@ # CMake toolchain to build for Android 5.0 or newer. Sample usage: # -# ANDROID_BIN="/path/to/android-ndk/toolchains/x86-4.9/prebuilt/linux-x86_64/" \ -# ANDROID_CPP="/path/to/android-ndk/sources/cxx-stl/llvm-libc++/" \ -# ANDROID_CC="/path/to/android-ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/clang" \ -# ANDROID_ROOT="/path/to/android-ndk/platforms/android-21/arch-x86/" \ -# cmake -DCMAKE_TOOLCHAIN_FILE=android-x86.cmake -DCMAKE_INSTALL_PREFIX=.. +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_ANDROID_ARCH_ABI x86) +set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") +set(CMAKE_ANDROID_STL_TYPE c++_shared) +set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) -set(CMAKE_SYSTEM_NAME UnixPaths) -set(CMAKE_SYSTEM_PROCESSOR atom) set(ANDROID TRUE) +if (WIN32) + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") + else() + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") +endif (WIN32) -set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") -set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") -set(CMAKE_C_LINK_EXECUTABLE " -target i686-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_LINK_EXECUTABLE " -target i686-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/x86/ -nostdlib++ -lc++_static -lc++abi -landroid_support -lm -lc") -set(CMAKE_C_CREATE_SHARED_LIBRARY " -target i686-none-linux-android -Wl,--no-undefined -z 
text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -target i686-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/x86/ -nostdlib++ -lc++_static -lc++abi -landroid_support -lm -lc") +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target x86-none-linux-android) -add_definitions(-D__ANDROID_API__=21 -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target i686-none-linux-android -march=i686 -mtune=atom -mssse3 -mfpmath=sse) - -include_directories("$ENV{ANDROID_CPP}/include/" "$ENV{ANDROID_CPP}/../llvm-libc++abi/include/" "$ENV{ANDROID_NDK}/sources/android/support/include/" "$ENV{ANDROID_CPP}/libs/x86/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/i686-linux-android/" "$ENV{ANDROID_ROOT}/usr/include/") diff --git a/libnd4j/cmake/android-x86_64.cmake b/libnd4j/cmake/android-x86_64.cmake index e249b3154..5ff797910 100644 --- a/libnd4j/cmake/android-x86_64.cmake +++ b/libnd4j/cmake/android-x86_64.cmake @@ -1,24 +1,21 @@ # CMake toolchain to build for Android 5.0 or newer. Sample usage: -# -# ANDROID_BIN="/path/to/android-ndk/toolchains/x86_64-4.9/prebuilt/linux-x86_64/" \ -# ANDROID_CPP="/path/to/android-ndk/sources/cxx-stl/llvm-libc++/" \ -# ANDROID_CC="/path/to/android-ndk/toolchains/llvm/prebuilt/linux-x86_64/bin/clang" \ -# ANDROID_ROOT="/path/to/android-ndk/platforms/android-21/arch-x86_64/" \ -# cmake -DCMAKE_TOOLCHAIN_FILE=android-x86_64.cmake -DCMAKE_INSTALL_PREFIX=.. 
-set(CMAKE_SYSTEM_NAME UnixPaths) -set(CMAKE_SYSTEM_PROCESSOR atom64) +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_ANDROID_ARCH_ABI x86_64) +set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") +set(CMAKE_ANDROID_STL_TYPE c++_shared) +set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) + set(ANDROID TRUE) +if (WIN32) + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") + else() + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") +endif (WIN32) -set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") -set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") -set(CMAKE_C_LINK_EXECUTABLE " -target x86_64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_LINK_EXECUTABLE " -target x86_64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/x86_64/ -nostdlib++ -lc++_static -lc++abi -lm -lc") -set(CMAKE_C_CREATE_SHARED_LIBRARY " -target x86_64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -lm -lc") -set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -target x86_64-none-linux-android -Wl,--no-undefined -z text -o -gcc-toolchain $ENV{ANDROID_BIN} --sysroot=$ENV{ANDROID_ROOT} -L$ENV{ANDROID_CPP}/libs/x86_64/ -nostdlib++ -lc++_static -lc++abi -lm -lc") - -add_definitions(-D__ANDROID_API__=21 -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target x86_64-none-linux-android -march=x86-64 -mtune=atom) - -include_directories("$ENV{ANDROID_CPP}/include/" "$ENV{ANDROID_CPP}/../llvm-libc++abi/include/" "$ENV{ANDROID_NDK}/sources/android/support/include/" "$ENV{ANDROID_CPP}/libs/x86_64/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/" "$ENV{ANDROID_NDK}/sysroot/usr/include/x86_64-linux-android/" "$ENV{ANDROID_ROOT}/usr/include/") 
+add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target x86_64-none-linux-android) diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 20b9d6562..8db086245 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -17,8 +17,8 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> org.deeplearning4j @@ -75,6 +75,7 @@ ${libnd4j.platform} + bash @@ -138,7 +139,7 @@ javacpp-cppbuild-validate validate - build + build @@ -150,8 +151,8 @@ ${libnd4j.cpu.compile.skip} - bash - ${project.basedir}/buildnativeoperations.sh + ${libnd4j.buildprogram} + buildnativeoperations.sh --build-type ${libnd4j.build} --chip @@ -183,7 +184,7 @@ ${libnd4j.test.skip} ${basedir}/tests_cpu - bash + sh run_tests.sh --chip ${libnd4j.chip} @@ -233,9 +234,35 @@ + + + + build-windows + + + Windows + + + + sh + + + + + + build-unix + + true + + + bash + + + + - libnd4-single-thread + libnd4j-single-thread libnd4j.singlethread @@ -310,8 +337,8 @@ ${libnd4j.cuda.compile.skip} - bash - ${project.basedir}/buildnativeoperations.sh + ${libnd4j.buildprogram} + buildnativeoperations.sh --build-type ${libnd4j.build} @@ -389,6 +416,9 @@ + + + libnd4j-helper-avx2 diff --git a/nd4j/compile-android.sh b/nd4j/compile-android.sh new file mode 100644 index 000000000..da7fa1799 --- /dev/null +++ b/nd4j/compile-android.sh @@ -0,0 +1 @@ +mvn clean install -Djavacpp.platform=android-arm64 -Dmaven.test.skip=true -Djavacpp.platform.compiler=$ANDROID_NDK/toolchains/llvm/prebuilt/windows-x86_64/bin/clang++ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml index 90748460a..c20e4553a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml @@ -30,11 +30,6 @@ - org.nd4j nd4j-api diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index 46566f50b..14cb6af6e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -29,7 +29,7 @@ 10.2 7.6 - 1.5.2 + ${javacpp-presets.version} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 379f69563..1b9af96b2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -51,16 +51,6 @@ ${dependency.platform} - org.nd4j nd4j-native-api @@ -165,8 +155,8 @@ /${javacpp.platform.library.path}/ /${javacpp.platform.library.path}/lib/ - /org/bytedeco/openblas/${javacpp.platform}/ - /org/bytedeco/openblas/${javacpp.platform}/lib/ + @@ -174,7 +164,7 @@ javacpp-validate validate - build + build @@ -297,23 +287,6 @@ -avx512 - - mingw - - windows - - - - - org.bytedeco - javacpp - - ${javacpp.platform}-mingw - - - - - libnd4j-assembly diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml index c355bf0a5..a47c601b4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml @@ -56,9 +56,7 @@ - + diff --git a/pom.xml b/pom.xml index a2188de1a..898e666a1 100644 --- a/pom.xml +++ b/pom.xml @@ -440,6 +440,7 @@ + doclint-java8-disable From 30a28fae458d21545dc1585dd5083bbc1030178e Mon Sep 17 00:00:00 2001 From: Adam Gibson <1144306+agibsonccc@users.noreply.github.com> Date: Fri, 20 Mar 2020 12:14:03 +0900 Subject: [PATCH 02/17] Windows fix (#333) * Fix cmake detection in msys * Revert windows change * Update to unix line endings --- libnd4j/buildnativeoperations.sh | 18 +++----- libnd4j/cmake/android-arm64.cmake | 44 
+++++++++---------- .../nd4j-backend-impls/nd4j-native/pom.xml | 19 ++++++++ 3 files changed, 46 insertions(+), 35 deletions(-) diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index 9a2c3e240..af9154866 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -181,8 +181,6 @@ if [ "$(uname)" == "Darwin" ]; then elif [ "$(expr substr $(uname -s) 1 5)" == "MINGW" ] || [ "$(expr substr $(uname -s) 1 4)" == "MSYS" ]; then HOST="windows" KERNEL="windows-x86_64" - # need to set build path separator, it ends up being wrong on msys2 - BUILD_PATH_SEPARATOR=";" echo "Running windows" elif [ "$(uname -m)" == "ppc64le" ]; then if [ -z "$ARCH" ]; then @@ -195,9 +193,9 @@ if [ -z "$OS" ]; then OS="$HOST" fi - - -echo "RUNNING BUILD FOR OS $OS" +if [[ -z ${ANDROID_NDK:-} ]]; then + export ANDROID_NDK=$HOME/Android/android-ndk/ +fi case "$OS" in linux-armhf) @@ -251,11 +249,8 @@ case "$OS" in if [ -z "$ARCH" ]; then ARCH="i686" fi - echo "BUILDING ANDROID x86" setandroid_defaults - - export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" @@ -364,7 +359,6 @@ case "$OS" in PARALLEL="true" VERBOSE_ARG="-v" else - echo "SETTING UP WINDOWS" export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\"" export MAKE_COMMAND="make" export CC=/mingw64/bin/gcc @@ -524,7 +518,6 @@ OPENBLAS_PATH="${OPENBLAS_PATH//\\//}" if [[ -n "${BUILD_PATH:-}" ]]; then PREVIFS="$IFS" IFS="$BUILD_PATH_SEPARATOR" - echo "BUILD PATH BUILD_PATH_SEPARATOR IS $BUILD_PATH_SEPARATOR" for P in $BUILD_PATH; do if [[ -f "$P/include/openblas_config.h" ]]; then OPENBLAS_PATH="$P" @@ -541,7 +534,6 @@ fi # replace any backslash with a slash OPENBLAS_PATH="${OPENBLAS_PATH//\\//}" - mkbuilddir() { if [ "$CLEAN" == "true" ]; then echo "Removing blasbuild" @@ -609,9 +601,9 @@ if [ "$MAKE_COMMAND" == 
"make" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS --output-sync=target" fi -exec 3>&1 +exec 3>&1 eval "$MAKE_COMMAND" "$MAKE_ARGUMENTS" 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../.. exec 3>&- else eval "$MAKE_COMMAND" "$MAKE_ARGUMENTS" && cd ../../.. -fi \ No newline at end of file +fi diff --git a/libnd4j/cmake/android-arm64.cmake b/libnd4j/cmake/android-arm64.cmake index 9d150b070..33ee454e7 100644 --- a/libnd4j/cmake/android-arm64.cmake +++ b/libnd4j/cmake/android-arm64.cmake @@ -1,22 +1,22 @@ -# CMake toolchain to build for Android 5.0 or newer. Sample usage: -# -set(CMAKE_SYSTEM_NAME Android) -set(CMAKE_ANDROID_ARCH_ABI arm64-v8a) -set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") -set(CMAKE_ANDROID_STL_TYPE c++_shared) -set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") -set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) - -set(ANDROID TRUE) -if (WIN32) - set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") - set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") - else() - set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") - set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") -endif (WIN32) - - - -add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android -march=armv8-a) - +# CMake toolchain to build for Android 5.0 or newer. 
Sample usage: +# +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_ANDROID_ARCH_ABI arm64-v8a) +set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") +set(CMAKE_ANDROID_STL_TYPE c++_shared) +set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION clang) + +set(ANDROID TRUE) +if (WIN32) + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}.exe") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++.exe") + else() + set(CMAKE_C_COMPILER "$ENV{ANDROID_CC}") + set(CMAKE_CXX_COMPILER "$ENV{ANDROID_CC}++") +endif (WIN32) + + + +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android -march=armv8-a) + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 1b9af96b2..a964f6918 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -287,6 +287,25 @@ -avx512 + + mingw + + + windows + + + + + + org.bytedeco + javacpp + + ${javacpp.platform}-mingw + + + + + libnd4j-assembly From 7a2ac800dd995919ebf9a99c027061fd21f57cb4 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 20 Mar 2020 08:49:28 +0300 Subject: [PATCH 03/17] Nullify (#304) * initial commit Signed-off-by: raver119 * bunch of tweaks Signed-off-by: raver119 * hamming distance nullification Signed-off-by: raver119 * Add output array value assignment for testing/debugging Signed-off-by: Alex Black * don't assign empty arrays Signed-off-by: raver119 * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 * few more fixes Signed-off-by: raver119 * im2col Signed-off-by: raver119 * pooling? 
Signed-off-by: raver119 * more nullified Signed-off-by: raver119 * ismax nullified Signed-off-by: raver119 * rollback ismax nullification Signed-off-by: raver119 * synchronized cublas handle use on per-device basis Signed-off-by: raver119 * hiding method from jcpp Signed-off-by: raver119 * get rid of test assigns in DeclarableOp Signed-off-by: raver119 * get rid of assigns Signed-off-by: raver119 * proper deviceId is back Signed-off-by: raver119 * include fixed Signed-off-by: raver119 Co-authored-by: Alex Black --- libnd4j/include/array/NDArray.h | 4 +- libnd4j/include/array/NDArray.hXX | 11 ++-- libnd4j/include/execution/LaunchContext.h | 9 ++- .../include/execution/cpu/LaunchContext.cpp | 9 +++ .../include/execution/cuda/LaunchContext.cu | 8 +++ .../include/helpers/cuda_off/MmulHelper.cu | 4 ++ libnd4j/include/ops/declarable/DeclarableOp.h | 1 + .../include/ops/declarable/PlatformHelper.h | 10 +++- .../generic/bitwise/bits_hamming_distance.cpp | 2 +- .../generic/compat/compat_sparse_to_dense.cpp | 2 +- .../generic/compat/compat_string_split.cpp | 2 +- .../declarable/generic/nn/convo/col2im.cpp | 4 +- .../declarable/generic/nn/convo/conv1d.cpp | 8 +-- .../declarable/generic/nn/convo/conv2d.cpp | 10 ++-- .../declarable/generic/nn/convo/deconv2d.cpp | 2 +- .../generic/nn/convo/deconv2d_tf.cpp | 2 +- .../generic/nn/convo/depthwiseConv2d.cpp | 8 +-- .../declarable/generic/nn/convo/im2col.cpp | 7 +-- .../declarable/generic/nn/convo/sconv2d.cpp | 14 ++--- .../generic/nn/convo/upsampling2d.cpp | 4 +- .../generic/nn/convo/upsampling3d.cpp | 4 +- .../generic/nn/pooling/avgpool2d.cpp | 4 +- .../generic/nn/pooling/avgpool3d.cpp | 4 +- .../generic/nn/pooling/maxpool2d.cpp | 4 +- .../generic/nn/pooling/maxpool3d.cpp | 4 +- .../nn/pooling/maxpool_with_argmax.cpp | 6 +- .../generic/nn/pooling/pnormpool2d.cpp | 4 +- .../declarable/generic/parity_ops/dropout.cpp | 4 +- .../declarable/generic/parity_ops/lstsq.cpp | 4 +- .../generic/parity_ops/matrix_determinant.cpp | 2 +- 
.../generic/parity_ops/segment_max.cpp | 4 +- .../generic/parity_ops/segment_mean.cpp | 4 +- .../generic/parity_ops/segment_min.cpp | 4 +- .../generic/parity_ops/segment_prod.cpp | 4 +- .../generic/parity_ops/segment_sum.cpp | 2 +- .../generic/parity_ops/sequence_mask.cpp | 2 +- .../parity_ops/unsorted_segment_max.cpp | 4 +- .../parity_ops/unsorted_segment_mean.cpp | 4 +- .../parity_ops/unsorted_segment_min.cpp | 4 +- .../parity_ops/unsorted_segment_prod.cpp | 4 +- .../parity_ops/unsorted_segment_sqrt_n.cpp | 4 +- .../parity_ops/unsorted_segment_sum.cpp | 4 +- .../declarable/generic/random/multinomial.cpp | 2 +- .../generic/thrid_party/firas_sparse.cpp | 2 +- .../declarable/generic/tsne/edge_force.cpp | 2 +- .../ops/declarable/helpers/cpu/col2im.cpp | 2 - .../declarable/helpers/cuda/batched_gemm.cu | 2 + .../ops/declarable/impl/DeclarableOp.cpp | 11 +++- .../ops/declarable/impl/PlatformHelper.cpp | 10 +++- .../ops/declarable/platform/mkldnn/conv2d.cpp | 6 +- .../ops/declarable/platform/mkldnn/conv3d.cpp | 6 +- .../platform/mkldnn/depthwiseConv2d.cpp | 6 +- libnd4j/include/system/op_boilerplate.h | 1 + .../layers_tests/ConvolutionTests1.cpp | 6 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 59 +++++++++++-------- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 59 +++++++++++-------- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 2 + 57 files changed, 229 insertions(+), 152 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 3fbfcef8e..6ab301200 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -277,13 +277,13 @@ namespace sd { /** * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently */ - NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + 
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true); /** * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently * set dtype as array type */ - NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true); /** * this constructor creates new array using shape information contained in vector argument diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index fa39a00f6..43c6fe2ad 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -143,7 +143,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector &sh //////////////////////////////////////////////////////////////////////// // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros -NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context) { +NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) { if (shapeInfo == nullptr) throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo"); @@ -161,7 +161,9 @@ NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyS if (!isEmpty()) { _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); + + if (nullify) + _buffer->setToZeroBuffers(); } } @@ -213,7 +215,7 @@ 
NDArray::NDArray(sd::LaunchContext * context) { //////////////////////////////////////////////////////////////////////// // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type -NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context): +NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context, const bool nullify): NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) { } @@ -3339,9 +3341,6 @@ void NDArray::nullify() { if (isEmpty()) return; - if (isS()) - throw std::runtime_error("NDArray::nullify: can't nullify string array"); - if (isView() || ews() != 1) assign(0); else diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index e2efa1418..4eaf2ca0f 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -54,6 +54,8 @@ class ND4J_EXPORT LaunchContext { static std::vector> _contexts; static std::mutex _mutex; + static MAP_IMPL _deviceMutexes; + // used for MKLDNN void *_engine = nullptr; @@ -93,7 +95,6 @@ class ND4J_EXPORT LaunchContext { void setCudaSpecialStream(cudaStream_t* cudaStream); void setCublasHandle(void *handle); - #endif // JCPP #endif // CUDA @@ -111,6 +112,12 @@ class ND4J_EXPORT LaunchContext { void setDeviceID(int deviceID) { _deviceID = deviceID; } sd::ErrorReference* errorReference(); +#ifndef __JAVACPP_HACK__ + // this method returns mutex shared between all threads that use the same device + static std::mutex* deviceMutex(); + +#endif + static bool isInitialized(); static void releaseBuffers(); diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 6217e0707..23df9f9f1 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -19,6 +19,7 @@ // #include 
+#include #include #include #include @@ -42,6 +43,7 @@ namespace sd { } std::vector> LaunchContext::_contexts = std::vector>(); + MAP_IMPL LaunchContext::_deviceMutexes; //////////////////////////////////////////////////////////////////////// LaunchContext::LaunchContext() { @@ -49,6 +51,8 @@ namespace sd { _workspace = nullptr; _deviceID = 0; + _deviceMutexes[_deviceID] = new std::mutex(); + #ifdef HAVE_MKLDNN _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); #endif @@ -68,6 +72,11 @@ namespace sd { return LaunchContext::_contexts[0].get(); } + std::mutex* LaunchContext::deviceMutex() { + auto deviceId = AffinityManager::currentDeviceId(); + return _deviceMutexes[deviceId]; + } + void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { // } diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 28193c3b0..8380e50bf 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -31,6 +31,7 @@ namespace sd { std::vector> LaunchContext::_contexts = std::vector>(); std::mutex LaunchContext::_mutex; + MAP_IMPL LaunchContext::_deviceMutexes; //////////////////////////////////////////////////////////////////////// LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { @@ -44,6 +45,11 @@ LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCuda _isAllocated = false; } + std::mutex* LaunchContext::deviceMutex() { + auto deviceId = AffinityManager::currentDeviceId(); + return _deviceMutexes[deviceId]; + } + LaunchContext::~LaunchContext() { if (_isAllocated) { @@ -85,6 +91,8 @@ LaunchContext::LaunchContext() { _contexts.resize(numDevices); for (int e = 0; e < numDevices; e++) { + _deviceMutexes[e] = new std::mutex(); + AffinityManager::setCurrentNativeDevice(e); LaunchContext::_contexts[e] = std::make_shared(); diff 
--git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 5e9304e88..fd1cd5813 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -252,6 +252,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; + std::lock_guard lock(*LaunchContext::deviceMutex()); + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); auto stream = A->getContext()->getCudaStream(); @@ -394,6 +396,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const bool typeDouble = AXY && aType == DataType::DOUBLE; const bool typeFloat = AXY && aType == DataType::FLOAT32; + std::lock_guard lock(*LaunchContext::deviceMutex()); + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); auto stream = A->getContext()->getCudaStream(); diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 5c01eecc8..3cce3b8e4 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -106,6 +106,7 @@ namespace sd { void storeResult(Context &block, int outputNumber, NDArray& array); void storeResult(Context &block, int outputNumber, NDArray* array); sd::NDArray* getZ(Context& block, int inputId = 0); + sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); /** * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time diff --git a/libnd4j/include/ops/declarable/PlatformHelper.h b/libnd4j/include/ops/declarable/PlatformHelper.h index b34a936ee..e0231ad9a 100644 --- a/libnd4j/include/ops/declarable/PlatformHelper.h +++ 
b/libnd4j/include/ops/declarable/PlatformHelper.h @@ -77,7 +77,15 @@ namespace sd { * @param inputId * @return */ - sd::NDArray *getZ(graph::Context &ctx, int inputId); + sd::NDArray* getZ(graph::Context &ctx, int inputId); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + sd::NDArray* getNullifiedZ(graph::Context &ctx, int inputId); }; } } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp index 10f7095e0..65f81b428 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp @@ -30,7 +30,7 @@ namespace sd { CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { auto x = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length"); REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type"); diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp index 2d6cf5f12..95dbdfcea 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp @@ -32,7 +32,7 @@ namespace sd { auto values = INPUT_VARIABLE(2); NDArray *def = nullptr; - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); if (block.width() > 3) def = INPUT_VARIABLE(3); diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp index a59e3f02c..f88710904 100644 --- 
a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -30,7 +30,7 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto delim = INPUT_VARIABLE(1); - auto indices = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_NULLIFIED(0); auto values = OUTPUT_VARIABLE(1); auto d = delim->e(0); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp index b68c4c211..d6e95a582 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp @@ -28,7 +28,7 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) { auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf()); REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf()); @@ -45,8 +45,6 @@ namespace sd { LaunchContext* ctx = block.launchContext(); helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX); - STORE_RESULT(*z); - return ND4J_STATUS_OK; } DECLARE_SHAPE_FN(col2im) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index bb32b1780..da711a569 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) int kW = INT_ARG(0) > 0 ? 
INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width int sW = INT_ARG(1); // strides width @@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width int sW = INT_ARG(1); // strides width diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index 29b6777ec..ace83e60c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -40,7 +40,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width @@ -161,9 +161,9 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? 
INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width @@ -267,7 +267,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 9f94d1459..12c1a9d3f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto bias = block.width() > 2 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp index 70fc46e0c..5503019f4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp @@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index c04bcf6dd..2bbcebb28 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto bias = block.width() > 2 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); @@ -152,9 +152,9 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always + auto gradB = block.width() > 3 ? 
OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp index 179dd3005..2e5818c56 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp @@ -30,8 +30,7 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) { auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - + auto z = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf()); REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf()); @@ -53,8 +52,6 @@ namespace sd { LaunchContext* ctx = block.launchContext(); sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext())); - STORE_RESULT(*z); - return Status::OK(); } @@ -107,7 +104,7 @@ namespace sd { CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) { auto input = INPUT_VARIABLE(0); auto gradAtOutput = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf()); REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index 928643493..b09f29101 100644 --- 
a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC - NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) if(block.width() == 3) { if((INPUT_VARIABLE(2))->rankOf() == 4) @@ -199,26 +199,26 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - NDArray *gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - NDArray *gradWD = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always NDArray *gradB = nullptr; // [oC] if(block.width() == 4) { if((INPUT_VARIABLE(3))->rankOf() == 4) { weightsPoint = INPUT_VARIABLE(3); - gradWP = OUTPUT_VARIABLE(2); + gradWP = OUTPUT_NULLIFIED(2); } else { bias = INPUT_VARIABLE(3); - gradB = OUTPUT_VARIABLE(2); + gradB = OUTPUT_NULLIFIED(2); } } else if(block.width() == 5) { weightsPoint = INPUT_VARIABLE(3); bias = INPUT_VARIABLE(4); - gradWP = OUTPUT_VARIABLE(2); - gradB = OUTPUT_VARIABLE(3); + gradWP = OUTPUT_NULLIFIED(2); + gradB = OUTPUT_NULLIFIED(3); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp index 4f04eb921..4800b3db9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp +++ 
b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp @@ -32,7 +32,7 @@ namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) { auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) const int factorH = INT_ARG(0); const int factorW = INT_ARG(1); @@ -97,7 +97,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) { // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) const int isNCHW = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp index f88f3705f..557468d14 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp @@ -31,7 +31,7 @@ namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) { auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) const int factorD = INT_ARG(0); const int factorH = INT_ARG(1); @@ -97,7 +97,7 @@ DECLARE_SHAPE_FN(upsampling3d) { CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) { // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) const int isNCDHW = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index 406f330ab..b15879df4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -31,7 +31,7 @@ namespace ops { CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; @@ -147,7 +147,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 76a7377a0..30d03c907 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -32,7 +32,7 @@ namespace ops { CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) int kD = INT_ARG(0); // filter(kernel) depth int kH = INT_ARG(1); // 
filter(kernel) height @@ -149,7 +149,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon const int kD = INT_ARG(0); // filter(kernel) depth const int kH = INT_ARG(1); // filter(kernel) height diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index 9e7115a73..13d65a681 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf()); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); const int kH = INT_ARG(0); const int kW = INT_ARG(1); @@ -150,7 +150,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width diff --git 
a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index 2804d5a7f..37cb34cb0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -32,7 +32,7 @@ namespace ops { CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) int kD = INT_ARG(0); // filter(kernel) depth int kH = INT_ARG(1); // filter(kernel) height @@ -151,7 +151,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon const int kD = INT_ARG(0); // filter(kernel) depth const int kH = INT_ARG(1); // filter(kernel) height diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index fabfd9bad..111846584 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -30,14 +30,14 @@ namespace sd { CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) { auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - auto indeces = OUTPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + auto indices = OUTPUT_NULLIFIED(1); 
REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf()); auto argI = *(block.getIArguments()); - helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indeces); + helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index 746f74da2..2c5fa66c1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -32,7 +32,7 @@ namespace sd { REQUIRE_OK(this->validateInputLengthMatch(block)); REQUIRE_OK(this->validateInputDimensionsMatch(block)); auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); @@ -145,7 +145,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp index 79c3c5cde..b64fd49d5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp @@ -33,7 +33,7 @@ CONFIGURABLE_OP_IMPL(dropout, 1, 1, true, 1, 1) { auto input = INPUT_VARIABLE(0); // 
lookup param NDArray *reduceShape = nullptr; // this param is optional - auto output = OUTPUT_VARIABLE(0); // + auto output = OUTPUT_NULLIFIED(0); // int seed = INT_ARG(0); @@ -66,7 +66,7 @@ CONFIGURABLE_OP_IMPL(dropout_bp, 2, 1, false, 1, 1) { NDArray* gradOut = INPUT_VARIABLE(1); // lookup param NDArray* reduceShape = nullptr; // this param is optional - NDArray* output = OUTPUT_VARIABLE(0); // + NDArray* output = OUTPUT_NULLIFIED(0); // int seed = INT_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp index df55db586..6b02f6d70 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp @@ -30,7 +30,7 @@ namespace sd { CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) { auto a = INPUT_VARIABLE(0); auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); bool fastFlag = true; double l2_factor = 0.; if (block.numB() > 0) { @@ -56,7 +56,7 @@ namespace sd { CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) { auto a = INPUT_VARIABLE(0); auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); bool fastFlag = true; double l2_factor = 0.; if (block.numB() > 0) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp index 6c1fd40fe..2268a9e9c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp @@ -114,7 +114,7 @@ namespace sd { CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf()); REQUIRE_TRUE(input->sizeAt(-1) == 
input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp index fe469e5ec..7ab19668a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp @@ -76,8 +76,8 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto indices = INPUT_VARIABLE(1); auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - auto outIndices = OUTPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); outIndices->assign(indices); return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp index ef35f4839..abb865d8e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp @@ -76,8 +76,8 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto indices = INPUT_VARIABLE(1); auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - auto outIndices = OUTPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); outIndices->assign(indices); return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp index 9c4a11255..a245b000b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp @@ -66,8 +66,8 @@ namespace sd { auto input = 
INPUT_VARIABLE(0); auto indices = INPUT_VARIABLE(1); auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - auto outIndices = OUTPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); outIndices->assign(indices); return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp index 576fae508..478eb9e23 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp @@ -67,8 +67,8 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto indices = INPUT_VARIABLE(1); auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - auto outIndices = OUTPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); outIndices->assign(indices); helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp index 203797e34..bb959fd3d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp @@ -65,7 +65,7 @@ namespace sd { CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) { - return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_VARIABLE(0)); + return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_NULLIFIED(0)); } DECLARE_SHAPE_FN(segment_sum_bp){ Nd4jLong* in = inputShape->at(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp 
b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index 310180685..6b0402ebb 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -25,7 +25,7 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); const int inRank = input->rankOf(); //REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp index 14bee3853..77e851104 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); @@ -67,7 +67,7 @@ namespace sd { } CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0)); + return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } DECLARE_TYPES(unsorted_segment_max_bp) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index 38032c0d7..cad59b7e9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); @@ -69,7 +69,7 @@ namespace sd { } CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0)); + return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } DECLARE_TYPES(unsorted_segment_mean_bp) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index 5d6c58b16..87b96e844 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); @@ -69,7 +69,7 @@ namespace sd { } CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0)); + return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } DECLARE_TYPES(unsorted_segment_min_bp) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index 1bb1d5bf5..e430c8f77 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); @@ -72,7 +72,7 @@ namespace sd { auto indices = INPUT_VARIABLE(1); auto eps = INPUT_VARIABLE(2); // auto numOfClasses = INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e(0) : INT_ARG(0); REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index e29a86a42..eeaa6e2c2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); @@ -68,7 +68,7 @@ namespace sd { } CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0)); + return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { getOpDescriptor() diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 89c6cb76f..941496424 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -26,7 +26,7 @@ namespace sd { CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); @@ -67,7 +67,7 @@ namespace sd { return SHAPELIST(CONSTANT(outputShape)); } CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0)); + return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } DECLARE_SHAPE_FN(unsorted_segment_sum_bp){ diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index c86417ef0..5361d1bbb 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -42,7 +42,7 @@ namespace sd { CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); auto inputSamples = INPUT_VARIABLE(1); diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index d385c2fa9..7860036ed 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -48,7 +48,7 @@ namespace sd { */ CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) { auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); int batchSize = x->sizeAt(0); int 
numColumns = x->sizeAt(1); diff --git a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp index 1d409c51f..64be499fb 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp @@ -34,7 +34,7 @@ namespace ops { auto dataP = INPUT_VARIABLE(3); auto N = INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf()); REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf()); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp index db6d27ffd..cf46df2db 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp @@ -54,8 +54,6 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output const Nd4jLong imStride1 = imStride[1]; const Nd4jLong imStride2 = imStride[2]; const Nd4jLong imStride3 = imStride[3]; - - memset(imBuff, 0, shape::length(imShapeBuffer) * sizeof(T)); // if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu index 3e3110b0e..b5447b411 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu @@ -116,6 +116,8 @@ void bgemm(const std::vector& vA, const std::vector& vB, std const auto bType = pB[0]->dataType(); const auto cType = pC[0]->dataType(); + std::lock_guard 
lock(*LaunchContext::deviceMutex()); + auto handle = reinterpret_cast(context->getCublasHandle()); auto stream = context->getCudaStream(); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 2c1d52348..44fbaae42 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -96,6 +96,14 @@ namespace sd { return _descriptor->getHash(); } + sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && !block.isInplace()) + result->nullify(); + + return result; + } + sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) { NDArray* z = nullptr; @@ -294,7 +302,8 @@ namespace sd { if (Environment::getInstance()->isDebugAndVerbose()) shape::printShapeInfoLinear("Going to create variable with shape", out); - auto outArr = new NDArray(out, true, ctx.launchContext()); + // we're creating non-initialized array here + auto outArr = new NDArray(out, true, ctx.launchContext(), false); ctx.pushNDArrayToVariableSpace(pair, outArr); diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index fe0928ce6..dfc18d33b 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -31,7 +31,15 @@ namespace sd { _engine = engine; } - sd::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) { + sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && !block.isInplace()) + result->nullify(); + + return result; + } + + sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) { NDArray *z = nullptr; if (ctx.isFastPath()) { diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp 
b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 9d236d293..a3ea56bb6 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -540,9 +540,9 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 6c0575378..0e853865b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -542,9 +542,9 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index ae4409923..1d365ef3a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -398,9 +398,9 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always + auto gradB = block.width() > 3 ? 
OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index b4df39a29..4e7a288f0 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1513,6 +1513,7 @@ #define INPUT_VARIABLE(INDEX) block.array(INDEX) #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) +#define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index b296a8903..f7e1ae7b9 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -2128,8 +2128,10 @@ TEST_F(ConvolutionTests1, col2im_test1) { auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f}); - LaunchContext ctx; - sd::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW); + + sd::ops::col2im op; + auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); + ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(image.equalsTo(imageExpected)); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 98521d58c..17bf95031 100644 --- 
a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -18,7 +18,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public IntVectorVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x); + public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -67,7 +67,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public LongVectorVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x); + public native @Name("operator =") @ByRef LongVectorVector put(@ByRef LongVectorVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -117,7 +117,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public NDArrayVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x); + public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -135,9 +135,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public Iterator(Pointer p) { super(p); } public Iterator() { } - public native @Name("operator++") @ByRef Iterator increment(); - public native @Name("operator==") boolean equals(@ByRef Iterator it); - public native @Name("operator*") @Const NDArray get(); + public native @Name("operator ++") 
@ByRef Iterator increment(); + public native @Name("operator ==") boolean equals(@ByRef Iterator it); + public native @Name("operator *") @Const NDArray get(); } public NDArray[] get() { @@ -185,7 +185,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public ConstNDArrayVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -203,9 +203,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public Iterator(Pointer p) { super(p); } public Iterator() { } - public native @Name("operator++") @ByRef Iterator increment(); - public native @Name("operator==") boolean equals(@ByRef Iterator it); - public native @Name("operator*") @Const NDArray get(); + public native @Name("operator ++") @ByRef Iterator increment(); + public native @Name("operator ==") boolean equals(@ByRef Iterator it); + public native @Name("operator *") @Const NDArray get(); } public NDArray[] get() { @@ -250,7 +250,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); } public IntIntPair() { allocate(); } private native void allocate(); - public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x); + public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x); @MemberGetter public native int first(); public native IntIntPair first(int first); @@ -3733,16 +3733,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from 
"shapeInfo", else calculate strides independently */ - public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo); - public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const 
bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo); - public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo); @@ -3750,16 +3750,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently * set dtype as array type */ - public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { 
super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); - public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, 
copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); - public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); @@ -9492,6 +9492,14 @@ public static final int PREALLOC_SIZE = 33554432; 
* @return */ public native NDArray getZ(@ByRef Context ctx, int inputId); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); } @@ -10289,7 +10297,6 @@ public static final int PREALLOC_SIZE = 33554432; // #ifndef __JAVACPP_HACK__ - // #endif // JCPP // #endif // CUDA @@ -10308,6 +10315,10 @@ public static final int PREALLOC_SIZE = 33554432; public native void setDeviceID(int deviceID); public native ErrorReference errorReference(); +// #ifndef __JAVACPP_HACK__ + +// #endif + public static native @Cast("bool") boolean isInitialized(); public static native void releaseBuffers(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 240dbc843..80d5904a6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -21,7 +21,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public IntVectorVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x); + public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -70,7 +70,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public LongVectorVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x); + public native @Name("operator =") 
@ByRef LongVectorVector put(@ByRef LongVectorVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -120,7 +120,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public ConstNDArrayVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -138,9 +138,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public Iterator(Pointer p) { super(p); } public Iterator() { } - public native @Name("operator++") @ByRef Iterator increment(); - public native @Name("operator==") boolean equals(@ByRef Iterator it); - public native @Name("operator*") @Const NDArray get(); + public native @Name("operator ++") @ByRef Iterator increment(); + public native @Name("operator ==") boolean equals(@ByRef Iterator it); + public native @Name("operator *") @Const NDArray get(); } public NDArray[] get() { @@ -188,7 +188,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public NDArrayVector(long n) { allocate(n); } private native void allocate(); private native void allocate(@Cast("size_t") long n); - public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x); + public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x); public boolean empty() { return size() == 0; } public native long size(); @@ -206,9 +206,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public Iterator(Pointer p) { super(p); } public Iterator() { } - public native @Name("operator++") @ByRef Iterator increment(); - public native @Name("operator==") boolean equals(@ByRef Iterator it); - public native @Name("operator*") @Const NDArray get(); + public 
native @Name("operator ++") @ByRef Iterator increment(); + public native @Name("operator ==") boolean equals(@ByRef Iterator it); + public native @Name("operator *") @Const NDArray get(); } public NDArray[] get() { @@ -253,7 +253,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); } public IntIntPair() { allocate(); } private native void allocate(); - public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x); + public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x); @MemberGetter public native int first(); public native IntIntPair first(int first); @@ -3736,16 +3736,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); /** * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently */ - public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public 
NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo); - public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo); - public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, 
nullify); } + private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo); @@ -3753,16 +3753,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently * set dtype as array type */ - public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); 
allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); - public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); - public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); } - private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext 
context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/); public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); @@ -11459,6 +11459,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #define INPUT_VARIABLE(INDEX) block.array(INDEX) // #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) +// #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) // #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) @@ -11809,6 +11810,14 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * @return */ public native NDArray getZ(@ByRef Context ctx, int inputId); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); } @@ -24205,6 +24214,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native void setDeviceID(int deviceID); public native ErrorReference errorReference(); +// #ifndef __JAVACPP_HACK__ + +// #endif + public static native @Cast("bool") boolean isInitialized(); public static native void releaseBuffers(); diff --git 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index d96c0ed31..5424d3c50 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -28,6 +28,7 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.imports.TFGraphs.NodeReader; +import org.nd4j.linalg.api.blas.BlasBufferUtil; import org.nd4j.linalg.api.blas.Level1; import org.nd4j.linalg.api.blas.params.GemmParams; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -106,6 +107,7 @@ import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Paths; import java.util.*; +import java.util.concurrent.CountDownLatch; import static org.junit.Assert.*; import static org.junit.Assert.assertArrayEquals; From 2497290cb0910f080c4acb601046025b36524e66 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 20 Mar 2020 17:25:46 +1100 Subject: [PATCH 04/17] AdaGrad validation test (#334) Signed-off-by: Alex Black --- .../linalg/api/buffer/DataBufferTests.java | 14 +++++----- .../nd4j/linalg/learning/UpdaterJavaCode.java | 10 +++++++ .../linalg/learning/UpdaterValidation.java | 26 +++++++++++++++++++ 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index b271c7bff..b7660bc6e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -33,7 +33,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; 
import org.nd4j.nativeblas.NativeOpsHolder; -import sun.nio.ch.DirectBuffer; import java.nio.ByteBuffer; @@ -406,15 +405,14 @@ public class DataBufferTests extends BaseNd4jTest { //https://github.com/eclipse/deeplearning4j/issues/8783 Nd4j.create(1); - DirectBuffer bb = (DirectBuffer) ByteBuffer.allocateDirect(5); - System.out.println(bb.getClass()); - System.out.println(bb.address()); - - Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bb.address()); - DataBuffer buff = Nd4j.createBuffer(ptr, 20, DataType.BYTE); + BytePointer bp = new BytePointer(5); - INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 1L, 'c', DataType.BYTE); + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(bp.address()); + DataBuffer buff = Nd4j.createBuffer(ptr, 5, DataType.INT8); + + + INDArray arr2 = Nd4j.create(buff, new long[]{5}, new long[]{1}, 0, 'c', DataType.INT8); long before = arr2.data().pointer().address(); Nd4j.getAffinityManager().ensureLocation(arr2, AffinityManager.Location.HOST); long after = arr2.data().pointer().address(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java index c80a04c55..5e640ec8b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java @@ -25,6 +25,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Map; +import static org.nd4j.linalg.ops.transforms.Transforms.sqrt; + public class UpdaterJavaCode { private UpdaterJavaCode(){ } @@ -46,6 +48,14 @@ public class UpdaterJavaCode { msdx.muli(rho).addi(update.mul(update).muli(1 - rho)); } + public static void applyAdaGradUpdater(INDArray gradient, INDArray state, double learningRate, double epsilon){ + 
state.addi(gradient.mul(gradient)); + + INDArray sqrtHistory = sqrt(state.dup('c'), false).addi(epsilon); + // lr * gradient / (sqrt(sumSquaredGradients) + epsilon) + gradient.muli(sqrtHistory.rdivi(learningRate)); + } + public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2, double epsilon, int iteration){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index e8df8d7ae..660b178e4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -69,6 +69,32 @@ public class UpdaterValidation extends BaseNd4jTest { } } + @Test + public void testAdaGradUpdater(){ + double lr = 0.1; + double epsilon = 1e-6; + + INDArray s = Nd4j.zeros(DataType.DOUBLE, 1, 5); + + Map state = new HashMap<>(); + state.put("grad", s.dup()); + AdaGradUpdater u = (AdaGradUpdater) new AdaGrad(lr, epsilon).instantiate(state, true); + + assertEquals(s, state.get("grad")); + + for( int i=0; i<3; i++ ) { + INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); + INDArray g2 = g1.dup(); + + UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon); + + u.applyUpdater(g2, i, 0); + + assertEquals(s, state.get("grad")); + assertEquals(g1, g2); + } + } + @Test public void testAdamUpdater(){ From 5dae4069cf788a9ea1c34bc675548d2aed5ad517 Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 20 Mar 2020 10:33:20 +0200 Subject: [PATCH 05/17] Shugeo random expo fix2 (#295) * Refactored exponential distribution implementation. Signed-off-by: shugeo * Refactored exponential distribution and tests. Signed-off-by: shugeo * Refactored test to new result sets. 
Signed-off-by: shugeo --- libnd4j/include/ops/random_ops.h | 8 ++-- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 53 +++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index 844f88ed3..d16b4f68a 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -119,13 +119,15 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow((T) M_E, -(lambda * x)); + T x = helper->relativeT(idx); //, T(0.f) , max); + T xVal = -sd::math::nd4j_log(T(1.f) - x); + + return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); } random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)1.f - sd::math::nd4j_pow((T) M_E, -(lambda * valueX)); + return valueX <= (T)0.f ? 
(T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp(-lambda * valueX); //pow((T) M_E, -(lambda * valueX)); } }; diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 64ab1781d..889e194a6 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -93,6 +93,21 @@ TEST_F(RNGTests, TestSeeds_2) { ASSERT_EQ(456, generator.nodeState()); } +TEST_F(RNGTests, TestGenerator_SGA_1) { + RandomGenerator generator(12, 13); + auto array= NDArrayFactory::create('c',{10000000}); + generator.setStates(123L, 456L); + for (auto idx = 0; idx < array.lengthOf(); idx++) { + float x = generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + array.t(idx) = x; + } + auto minimum = array.reduceNumber(reduce::AMin); + minimum.printBuffer("Randomly float min on 1M array"); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); +} + TEST_F(RNGTests, Test_Dropout_1) { auto x0 = NDArrayFactory::create('c', {10, 10}); @@ -573,6 +588,15 @@ TEST_F(RNGTests, Test_Uniform_2) { } +TEST_F(RNGTests, Test_Uniform_SGA_3) { + //auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max(), sd::DataTypeUtils::template max()); + auto minimumU = x1.reduceNumber(reduce::AMin); + minimumU.printBuffer("\nMinimum"); +} + TEST_F(RNGTests, Test_Gaussian_2) { auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); auto x1 = NDArrayFactory::create('c', {10, 10}); @@ -728,8 +752,37 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) { auto z = result.at(0); ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); + // + z->printBuffer("\nExponential1"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = 
z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); + variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + +// delete result; +} + +TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + z->printBuffer("\nExponential2"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. 
(1 exp) is"); ASSERT_FALSE(nexp0->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z)); From e700b59f802d902cad9ade87cc516c40120fef37 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Fri, 20 Mar 2020 11:11:27 +0200 Subject: [PATCH 06/17] Shyrma weights format (#329) * - start to introduce additional weights formats into conv2d ops Signed-off-by: Yurii * - provide weights format variety in backprop conv2d and deconv2d ops, testing and fixing bugs Signed-off-by: Yurii * - forgot to recover kernels sizes in deconv2d_bp test Signed-off-by: Yurii * - built in weights format in depthwise conv 2d op Signed-off-by: Yurii * - provide new weights formats in mkl dnn conv ops Signed-off-by: Yurii * - provide new weights formats in cuda conv helpers Signed-off-by: Yurii * - working with new weights format in cudnn conv api Signed-off-by: Yurii * - take into account order of arrays in cudnn tensor descriptions Signed-off-by: Yurii * - provide new weights formats in cpu conv3d (ff/bp) Signed-off-by: Yurii * - provide new weights formats in cpu deconv3d (ff/bp) Signed-off-by: Yurii * - provide new weights formats in conv3d ops (ff/bp) based on mkl api Signed-off-by: Yurii * - provide new weights formats in conv3d ops (ff/bp) based on cudnn api Signed-off-by: Yurii * - resolve conflicts 2 Signed-off-by: Yurii Co-authored-by: raver119 --- libnd4j/include/helpers/shape.h | 2 +- .../declarable/generic/nn/convo/conv1d.cpp | 37 +- .../declarable/generic/nn/convo/conv2d.cpp | 52 +- .../declarable/generic/nn/convo/conv3d.cpp | 68 +- .../declarable/generic/nn/convo/deconv2d.cpp | 60 +- .../generic/nn/convo/deconv2d_tf.cpp | 19 +- .../declarable/generic/nn/convo/deconv3d.cpp | 61 +- .../generic/nn/convo/depthwiseConv2d.cpp | 32 +- .../generic/nn/convo/pointwiseConv2d.cpp | 18 +- .../declarable/generic/nn/convo/sconv2d.cpp | 65 +- .../generic/nn/pooling/avgpool2d.cpp | 2 +- .../generic/nn/pooling/avgpool3d.cpp | 4 +- .../generic/nn/pooling/maxpool2d.cpp | 2 +- 
.../generic/nn/pooling/maxpool3d.cpp | 4 +- .../generic/nn/pooling/pnormpool2d.cpp | 2 +- .../ops/declarable/helpers/convolutions.h | 67 +- .../declarable/helpers/cpu/convolutions.cpp | 126 ++- .../declarable/helpers/cuda/convolutions.cu | 117 ++- .../declarable/platform/cudnn/avgpool2d.cu | 2 +- .../declarable/platform/cudnn/avgpool3d.cu | 4 +- .../ops/declarable/platform/cudnn/conv2d.cu | 91 +- .../ops/declarable/platform/cudnn/conv3d.cu | 76 +- .../declarable/platform/cudnn/cudnnUtils.cu | 24 +- .../platform/cudnn/depthwiseConv2d.cu | 81 +- .../declarable/platform/cudnn/maxpool2d.cu | 2 +- .../declarable/platform/cudnn/maxpool3d.cu | 4 +- .../platform/mkldnn/avgpooling2d.cpp | 4 +- .../platform/mkldnn/avgpooling3d.cpp | 4 +- .../declarable/platform/mkldnn/batchnorm.cpp | 22 +- .../ops/declarable/platform/mkldnn/conv2d.cpp | 147 +-- .../ops/declarable/platform/mkldnn/conv3d.cpp | 162 ++-- .../declarable/platform/mkldnn/deconv2d.cpp | 128 +-- .../platform/mkldnn/deconv2d_tf.cpp | 27 +- .../declarable/platform/mkldnn/deconv3d.cpp | 134 +-- .../platform/mkldnn/depthwiseConv2d.cpp | 130 +-- .../declarable/platform/mkldnn/lstmLayer.cpp | 16 +- .../ops/declarable/platform/mkldnn/matmul.cpp | 4 +- .../platform/mkldnn/maxpooling2d.cpp | 4 +- .../platform/mkldnn/maxpooling3d.cpp | 4 +- .../platform/mkldnn/mkldnnUtils.cpp | 52 +- .../declarable/platform/mkldnn/mkldnnUtils.h | 6 +- .../declarable/platform/mkldnn/softmax.cpp | 16 +- .../ops/declarable/platform/mkldnn/tanh.cpp | 18 +- .../layers_tests/ConvolutionTests1.cpp | 850 +++++++++++------- .../layers_tests/ConvolutionTests2.cpp | 509 ++++++++++- 45 files changed, 2185 insertions(+), 1074 deletions(-) diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index a6b22ba6d..2c18615fc 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -4076,7 +4076,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap // *** FIRST STAGE - exclude 
unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code - // FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) + // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) Nd4jLong tempBuffer[4*MAX_RANK]; Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides; diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index da711a569..27081b545 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -34,7 +34,7 @@ namespace ops { CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) @@ -45,12 +45,13 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC + int wFormat = block.getIArguments()->size() > 6 ? 
INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); - int indIOioC, indIiW, indWoC(2); + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { indIOioC = 2; indIiW = 1; } @@ -63,7 +64,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { int iC = input->sizeAt(indIOioC); // input channels int oC = weights->sizeAt(indWoC); // output channels - std::vector expectedWeightsShape = {kW, iC, oC}; + std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -83,11 +84,11 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); + const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); if (status != ND4J_STATUS_OK) return status; - // 
ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); + // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); return Status::OK(); } @@ -105,8 +106,9 @@ DECLARE_SHAPE_FN(conv1d) { int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - int indIOioC, indIiW, indWoC(2); + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { indIOioC = 2; indIiW = 1; } @@ -123,7 +125,7 @@ DECLARE_SHAPE_FN(conv1d) { int iC = inputShapeInfo[indIOioC+1]; // input channels int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {kW, iC, oC}; + std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? 
std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); @@ -163,12 +165,12 @@ DECLARE_TYPES(conv1d) { CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width @@ -177,12 +179,14 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.getIArguments()->size() > 6 ? 
INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); - int indIOioC, indIiW, indWoC(2); + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { indIOioC = 2; indIiW = 1; } @@ -199,7 +203,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = {kW, iC, oC}; + std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? 
std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -222,11 +226,11 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC] sd::ops::conv2d_bp conv2dBP; - auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); + auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); if (status != ND4J_STATUS_OK) return status; - // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); + // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); return Status::OK(); } @@ -235,7 +239,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { DECLARE_SHAPE_FN(conv1d_bp) { auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, 
kW, iC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next @@ -250,8 +254,9 @@ DECLARE_SHAPE_FN(conv1d_bp) { int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - int indIOioC, indIiW, indWoC(2); + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { indIOioC = 2; indIiW = 1; } @@ -268,7 +273,7 @@ DECLARE_SHAPE_FN(conv1d_bp) { ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = {kW, iC, oC}; + std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? 
std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index ace83e60c..4377c1487 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -37,7 +37,7 @@ namespace ops { CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -49,21 +49,22 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? 
INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); return Status::OK(); } @@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { DECLARE_SHAPE_FN(conv2d) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto biasShapeInfo = block.width() > 2 ? 
inputShape->at(2) : nullptr; // [oC] //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -86,6 +87,7 @@ DECLARE_SHAPE_FN(conv2d) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width @@ -95,7 +97,7 @@ DECLARE_SHAPE_FN(conv2d) { REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - int indIOioC, indIiH, indWoC(3); + int indIOioC, indIiH, indWoC(0 == wFormat ? 
3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(conv2d) { const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); @@ -157,12 +159,12 @@ DECLARE_SHAPE_FN(conv2d) { CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradB = block.width() > 3 ? 
OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height @@ -175,6 +177,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); @@ -182,19 +185,19 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vectorexpectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vectorexpectedWeightsShape = {kH, kW, iC, oC}; + std::vectorexpectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), 
ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); return Status::OK(); } @@ -204,7 +207,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { DECLARE_SHAPE_FN(conv2d_bp) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next @@ -224,8 +227,9 @@ DECLARE_SHAPE_FN(conv2d_bp) { const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - int indIOioC, indIiH, indOoH, indWoC(3); + int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 
3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } @@ -243,7 +247,7 @@ DECLARE_SHAPE_FN(conv2d_bp) { ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) @@ -264,7 +268,7 @@ DECLARE_SHAPE_FN(conv2d_bp) { CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { auto gradIShape = INPUT_VARIABLE(0); // [4] - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon @@ -279,6 +283,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); @@ -295,17 +300,17 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); return Status::OK(); } @@ -321,7 +326,7 @@ 
CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { DECLARE_SHAPE_FN(conv2d_input_bp) { auto gradIShapeShapeInfo = inputShape->at(0); // [4] - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next const int rank = 4; @@ -340,8 +345,9 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - int indIOioC, indIiH, indWoC(3), indOoH; + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } @@ -361,7 +367,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), 
ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 669af1940..0657f6dc2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -32,7 +32,7 @@ namespace ops { CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) @@ -52,14 +52,15 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -71,14 +72,24 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { std::vector permutForOutput; if (isNCDHW) - permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] + permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] else input = new NDArray(input->permute({0,4,1,2,3})); + std::vector wAxes; + if(0 == wFormat) + wAxes = {3,0,1,2}; + else if(1 == wFormat) + wAxes = {1,2,3,4}; + else + wAxes = {4,1,2,3}; + NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); ConvolutionUtils::vol2col(block, *input, 
columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] - MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); + // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC] + // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC] + MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput); if(bias) // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); @@ -101,7 +112,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { DECLARE_SHAPE_FN(conv3dnew) { auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth @@ -118,13 +129,14 @@ DECLARE_SHAPE_FN(conv3dnew) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] const int rank = 5; REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - int indIOioC, indIiD, indWoC(4); + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); if(!isNCDHW) { indIOioC = 4; indIiD = 1; } @@ -139,7 +151,7 @@ DECLARE_SHAPE_FN(conv3dnew) { int iC = inputShapeInfo[indIOioC+1]; // input channels int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); @@ -174,12 +186,12 @@ DECLARE_SHAPE_FN(conv3dnew) { CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 3 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); @@ -200,17 +212,18 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); int trueoD, trueoH, trueoW; // true output depth/height/width ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); std::vector expectedGradOShape = 
ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -231,10 +244,25 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW } + std::vector wPermut, colPermut; + + if(0 == wFormat) { + wPermut = {3,0,1,2,4}; + colPermut = {2,3,4,1,0,5,6,7}; + } + else if(1 == wFormat) { + wPermut = {1,2,3,4,0}; + colPermut = {1,2,3,4,0,5,6,7}; + } + else { + wPermut = {4,1,2,3,0}; + colPermut = {2,3,4,1,0,5,6,7}; + } + // ----- calculation of gradW and gradB ----- // NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] + MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] //----- calculation of gradO -----// if(gradB) { @@ -246,7 +274,10 @@ 
CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { } //----- calculation of gradI -----// - MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] + // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] + // [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] + // [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] + MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] if(!isNCDHW) { @@ -270,7 +301,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { DECLARE_SHAPE_FN(conv3dnew_bp) { Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC] always + Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next @@ -288,6 +319,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] const int rank = 5; REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); @@ -295,7 +327,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); - int indIOioC, indIiD, indWoC(4); + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); if(!isNCDHW) { indIOioC = 4; indIiD = 1; } @@ -314,7 +346,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp 
index 12c1a9d3f..8d6c0e3a7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -35,7 +35,7 @@ namespace ops { CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -53,12 +53,13 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC 
== bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -66,6 +67,12 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { if(!isNCHW) output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + std::vector colPermut; + if(1 == wFormat) + colPermut = {1, 2, 3, 0, 4, 5}; + else + colPermut = {2, 3, 1, 0, 4, 5}; + if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); @@ -73,8 +80,9 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { //----- calculation of output -----// // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] - // NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); + // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW] + // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); LaunchContext* ctx = block.launchContext(); helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] @@ -97,7 +105,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { DECLARE_SHAPE_FN(deconv2d) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto biasShapeInfo = block.width() > 2 ? 
inputShape->at(2) : nullptr; // [oC] const int rank = 4; @@ -114,8 +122,9 @@ DECLARE_SHAPE_FN(deconv2d) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - int indIOioC, indIiH, indWoC(2); + int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -129,7 +138,7 @@ DECLARE_SHAPE_FN(deconv2d) { const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); @@ -163,12 +172,12 @@ DECLARE_SHAPE_FN(deconv2d) { CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? 
INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -186,16 +195,17 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of 
output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -206,29 +216,34 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } - - // ----- calculation of gradI -> pass it through conv2d_ff ----- // + // ----- calculation of gradI -> pass it through conv2d_ff ----- // sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW}, {}); + const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW, wFormat}, {}); if (status != ND4J_STATUS_OK) return status; // -----prepare permutation arrays and axes for dot product ----- // - std::vector inputAxesForDot; + std::vector inputAxes; if(!isNCHW) { gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - inputAxesForDot = {0, 1, 2}; // bS, iH, iW + inputAxes = {0, 1, 2}; // bS, iH, iW } else - inputAxesForDot = {0, 2, 3}; // bS, iH, iW + inputAxes = {0, 2, 3}; // bS, iH, iW + + std::vector gradWAxes; // empty for wFormat = 1 + if(0 == wFormat) + gradWAxes = {3, 2, 0, 1}; + else if(2 == wFormat) + gradWAxes = {0, 3, 1, 2}; // ----- calculation of gradW ----- // NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); LaunchContext* ctx = block.launchContext(); helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, 
iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 4, 5}, {3, 2, 0, 1}); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW] + MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW] // ----- calculation of gradB ----- // if(gradB) { @@ -248,7 +263,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { DECLARE_SHAPE_FN(deconv2d_bp) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next @@ -267,8 +282,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - int indIOioC, indIiH, indWoC(2), indOoH; + int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 
1 : 3)); if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } @@ -286,7 +302,7 @@ DECLARE_SHAPE_FN(deconv2d_bp) { ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp index 5503019f4..ae97c3d65 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp @@ -32,10 +32,10 @@ namespace ops { CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - auto gradI = OUTPUT_NULLIFIED(0); // 
[bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width @@ -47,6 +47,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); @@ -57,20 +58,19 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { // create empty conv2d input array NDArray input(gradO->ordering(), gradIShape->asVectorT(), gradO->dataType(), block.launchContext()); - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: 
wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); return Status::OK(); } @@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { DECLARE_SHAPE_FN(deconv2d_tf) { auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC] always + auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradIShapeShapeInfo = inputShape->at(0); // [4] const int rank = 4; @@ -103,8 +103,9 @@ DECLARE_SHAPE_FN(deconv2d_tf) { const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - int indIOioC, indIiH, indWoC(3), indOoH; + int indIOioC, indIiH, indWoC(0 == wFormat ? 
3 : 0), indOoH; if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; } @@ -126,7 +127,7 @@ DECLARE_SHAPE_FN(deconv2d_tf) { ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index d4899fbab..ab6e49836 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -32,7 +32,7 @@ namespace ops { CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto bias = block.width() > 2 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) @@ -53,13 +53,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -67,16 +68,23 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { if(!isNCDHW) output = new NDArray(output->permute({0, 4, 1, 2, 3})); // 
[bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + std::vector colPermut; + if(1 == wFormat) + colPermut = {1,2,3,4,0,5,6,7}; + else + colPermut = {2,3,4,1,0,5,6,7}; + if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); //----- calculation of output -----// - // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] - // NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] - ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] + // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] + // [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW] + // [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] + ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] //----- add biases if required -----// if(bias) @@ -101,7 +109,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { DECLARE_SHAPE_FN(deconv3d) { auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always + auto 
weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] const int rank = 5; @@ -122,8 +130,9 @@ DECLARE_SHAPE_FN(deconv3d) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - int indIOioC, indIiD, indWoC(3); + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); if(!isNCDHW) { indIOioC = 4; indIiD = 1; } @@ -138,7 +147,7 @@ DECLARE_SHAPE_FN(deconv3d) { const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); @@ -174,12 +183,12 @@ DECLARE_SHAPE_FN(deconv3d) { CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); 
// [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); @@ -201,16 +210,17 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); int trueoD, trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -221,7 +231,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { // ----- calculation of gradI -> pass it through conv3d_ff ----- // sd::ops::conv3dnew conv3d; - const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW}, {}); + const Nd4jStatus status = 
conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW, wFormat}, {}); if (status != ND4J_STATUS_OK) return status; @@ -235,10 +245,16 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { else inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW + std::vector gradWAxes; // empty for wFormat = 1 + if(0 == wFormat) + gradWAxes = {4,3,0,1,2}; + else if(2 == wFormat) + gradWAxes = {0,4,1,2,3}; + // ----- calculation of gradW ----- // auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] + ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] + MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] // ----- calculation of gradB ----- // if(gradB) { @@ -267,7 +283,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { DECLARE_SHAPE_FN(deconv3d_bp) { auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC] always + auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] Nd4jLong* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] Nd4jLong* gradOShapeInfo = block.width() > 3 ? 
inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next @@ -290,8 +306,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - int indIOioC, indIiD, indWoC(3); + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); if(!isNCDHW) { indIOioC = 4; indIiD = 1; } @@ -310,8 +327,8 @@ DECLARE_SHAPE_FN(deconv3d_bp) { ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; - REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, 
"CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index 2bbcebb28..30580e7a6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -32,7 +32,7 @@ namespace ops { CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) @@ -50,19 +50,20 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); return Status::OK(); } @@ -75,7 +76,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { DECLARE_SHAPE_FN(depthwise_conv2d) { Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - Nd4jLong* 
weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always + Nd4jLong* weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC const int rank = 4; @@ -92,8 +93,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - int indIOioC, indIiH, indWmC(3); + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -109,7 +111,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { const int oC = iC*mC; // output channels - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); @@ -148,12 +150,12 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, 
kW], [mC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -170,23 +172,24 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier int trueoH, trueoW; // correct output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, 
pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); return Status::OK(); } @@ -214,8 +217,9 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - int indIOioC, indIiH, indWmC(3); + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -234,7 +238,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if(biasShapeInfo) diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp 
b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp index 02d81493a..52960c3fc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp @@ -29,7 +29,7 @@ namespace ops { CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) @@ -47,18 +47,19 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { int pW = 0; // paddings width int dH = 1; // dilations height int dW = 1; // dilations width - int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 1 ? 
INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - std::vector expectedWeightsShape = {1, 1, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat); return Status::OK(); } @@ -73,7 +74,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(pointwise_conv2d) { Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC] always + Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] Nd4jLong* biasShapeInfo = block.width() > 2 ? 
inputShape->at(2) : nullptr; // [oC] const int rank = 4; @@ -81,8 +82,9 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] - int indIOioC, indWoC(3); + int indIOioC, indWoC(0 == wFormat ? 3 : 0); if(!isNCHW) indIOioC = 3; else @@ -92,7 +94,7 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { const int iC = inputShapeInfo[indIOioC+1]; // input channels const int oC = weightsShapeInfo[indWoC+1]; // output channels - std::vector expectedWeightsShape = {1, 1, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index b09f29101..a804abafa 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -33,8 +33,8 @@ namespace ops { CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - NDArray *weightsDepth = 
INPUT_VARIABLE(1); // [kH, kW, iC, mC] always - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always + NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -66,17 +66,19 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weightsDepth->sizeAt(indWmC); // channels multiplier - std::vector expectedWeightsDShape = {kH, kW, iC, mC}; + std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), 
ShapeUtils::shapeAsString(weightsDepth).c_str()); if(weightsPoint) { - std::vector expectedWeightsPShape = {1, 1, iC*mC, oC}; + std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); } if (bias) @@ -84,11 +86,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { if (iC == 1) { nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); - ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); return Status::OK(); } - ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); return Status::OK(); } @@ -103,8 +105,8 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { DECLARE_SHAPE_FN(sconv2d) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC] always - Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always + auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr if(block.width() == 3) @@ -135,8 +137,9 @@ DECLARE_SHAPE_FN(sconv2d) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 
0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - int indIOioC, indIiH, indWmC(3); + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -148,13 +151,13 @@ DECLARE_SHAPE_FN(sconv2d) { const int iH = inputShapeInfo[indIiH+1]; // input height const int iW = inputShapeInfo[indIiH+2]; // input width const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier - const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) + const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier + const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) - std::vector expectedWeightsDShape = {kH, kW, iC, mC}; + std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = {1, 1, iC*mC, oC}; + std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); } if (biasShapeInfo) @@ -195,13 +198,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { NDArray *input = INPUT_VARIABLE(0); // [bS, iH, 
iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC] always - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always + NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always - NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always + NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] NDArray *gradB = nullptr; // [oC] if(block.width() == 4) { @@ -244,17 +247,18 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weightsDepth->sizeAt(indWmC); // channels multiplier - std::vector expectedWeightsDShape = {kH, kW, iC, mC}; + std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); if(weightsPoint) { - std::vector expectedWeightsPShape = {1, 1, iC*mC, oC}; + std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), 
ShapeUtils::shapeAsString(gradWP).c_str()); } @@ -274,12 +278,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { auto resultFFShape = isNCHW ? std::vector({bS, mC*iC, oH, oW}) : std::vector({bS, oH, oW, mC*iC}); auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); - ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW + ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat); // in this case oH=iH and oW=iW gradO = gradIDepth; bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step @@ -288,7 +292,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { } // ----- apply depthwise_conv2d_bp ----- // - ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); if(weightsPoint) delete gradO; @@ -301,8 +305,8 @@ DECLARE_SHAPE_FN(sconv2d_bp) { auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto 
gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC] always - Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC] always + auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] Nd4jLong* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr if(block.width() == 4) { @@ -335,8 +339,9 @@ DECLARE_SHAPE_FN(sconv2d_bp) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - int indIOioC, indIiH, indWmC(3); + int indIOioC, indIiH, indWmC(0 == wFormat ? 
3 : 0); if(!isNCHW) { indIOioC = 3; indIiH = 1; } @@ -356,10 +361,10 @@ DECLARE_SHAPE_FN(sconv2d_bp) { std::vector expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - std::vector expectedWeightsDShape = {kH, kW, iC, mC}; + std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = {1, 1, iC*mC, oC}; + std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); } if (biasShapeInfo) diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index b15879df4..b93cbe47f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -166,7 +166,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, 
indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 30d03c907..85b8d8833 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); @@ -172,7 +172,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch 
size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index 13d65a681..d92c27442 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -168,7 +168,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp 
b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index 37cb34cb0..3fd5f9c51 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); @@ -174,7 +174,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector expectedGradIShape = 
ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index 2c5fa66c1..4c9319ca1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -167,7 +167,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index 6ba6136a4..f38692a35 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -154,15 +154,24 @@ namespace sd { } // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { - getSizesAndIndexesConv2d(isNCHW, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, 
oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { + getSizesAndIndexesConv2d(isNCHW, wFormat, input.getShapeInfo(), output.getShapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); } - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { + static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always + // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2) // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - indWkH = 0; indWiC = 2; indWoC = 3; + + if(0 == wFormat) { + indWkH = 0; indWiC = 2; indWoC = 3; + } + else if(1 == wFormat) { + indWkH = 2; indWiC = 1; indWoC = 0; + } + else { + indWkH = 1; indWiC = 3; indWoC = 0; + } if(!isNCHW) { indIOioC = 3; indIiH = 1; indOoH = 1; @@ -181,12 +190,21 @@ namespace sd { } // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { + 
static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - // weights [kD, kH, kW, iC, oC] (NDHWC) or [oC, iC, kD, kH, kW] (NCDHW) + // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2) // output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - indWkD = 0; indWiC = 3; indWoC = 4; + if(0 == wFormat) { + indWkD = 0; indWiC = 3; indWoC = 4; + } + else if(1 == wFormat) { + indWkD = 2; indWiC = 1; indWoC = 0; + } + else { + indWkD = 1; indWiC = 4; indWoC = 0; + } + if(!isNCDHW) { indIOioC = 4; indIOioD = 1; } @@ -203,7 +221,6 @@ namespace sd { oD = output.sizeAt(indIOioD); // output depth oH = output.sizeAt(indIOioD+1); // output height oW = output.sizeAt(indIOioD+2); // output width - } // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) { @@ -254,19 +271,41 @@ namespace sd { // } // } - static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); + static std::vector expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) { + + if(0 == wFormat) + return std::vector({kH, kW, iC, oC}); + + if(1 == wFormat) + return std::vector({oC, iC, kH, kW}); + + return std::vector({oC, kH, kW, iC}); + } + + static std::vector expectWeightsShape(const int wFormat, const int kD, const int kH, const 
int kW, const int iC, const int oC) { + + if(0 == wFormat) + return std::vector({kD, kH, kW, iC, oC}); + + if(1 == wFormat) + return std::vector({oC, iC, kD, kH, kW}); + + return std::vector({oC, kD, kH, kW, iC}); + } + + static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); // static void conv2d(sd::graph::Context & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); // static void conv2dBP(sd::graph::Context & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); - static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); + static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); + static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int 
paddingMode, const int isNCHW, const int wFormat); - static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); + static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); + static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index f852bed23..4140c2143 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -258,10 +258,10 @@ namespace sd { 
////////////////////////////////////////////////////////////////////////// template - static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // bias [oC] // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -278,7 +278,7 @@ namespace sd { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); @@ -291,6 +291,14 @@ namespace sd { else input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + std::vector wAxes; + if(0 == wFormat) + wAxes = {0, 1, 2}; + else if(1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} NDArray mmulResult('f', {bS*oH*oW, 
oC}, output->dataType(), output->getContext()); @@ -298,7 +306,7 @@ namespace sd { //----- calculation of output -----// auto ctx = block.launchContext(); helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] //----- assign outTemp to output -----// if(isNCHW) { @@ -319,15 +327,15 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// template - static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // bias [oC] // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC] always + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // gradB [oC] // kH filter(kernel) height @@ -343,7 +351,7 @@ namespace sd { int bS, iC, iH, iW, oC, oH, oW; 
// batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); @@ -359,13 +367,28 @@ namespace sd { gradOaxesForDot = {0, 2, 3}; // bS, oH, oW } + std::vector wPermut, colPermut; + + if(0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + else if(1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } + else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); // ----- calculation of gradW ----- // if(gradW) { auto ctx = block.launchContext(); helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] + sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] } // ----- calculation of gradB ----- // @@ -379,9 +402,12 @@ namespace sd { } //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [kH, kW, iC, oC] x [bS, oH, 
oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] + // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] if(!isNCHW) { delete input; @@ -391,10 +417,10 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// template - static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC] always + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // bias [oC] = iC*mC // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) @@ -411,23 +437,30 @@ namespace sd { int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, 
indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput; + std::vector> modifOutput, modifWeights; std::vector outReShape; if(!isNCHW) { outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] } else { outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] } + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -435,7 +468,7 @@ namespace sd { NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, 
mC] if(bias) // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); @@ -447,14 +480,14 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// template - static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC] always + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // bias [oC] = [iC*mC] // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC] always + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // gradB [oC] // kH filter(kernel) height @@ -470,19 +503,19 @@ namespace sd { int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, 
kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2; + std::vector> modifGradO1, modifGradO2, modifWeights; std::vector gradOreShape; if(!isNCHW) { gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] } else { gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] @@ -490,6 +523,13 @@ namespace sd { modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] } + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -499,7 +539,7 @@ namespace sd { // ----- calculation of gradW and gradB ----- // helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] // 
----- calculation of gradB ----- // if(gradB) { @@ -513,8 +553,8 @@ namespace sd { } //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] if(!isNCHW) { delete input; @@ -524,11 +564,11 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// template - static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC] always - // weightsPoint [1, 1, iC*mC, oC] always + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] // bias [oC], oC = iC*mC if weightsPoint=nullptr // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -545,7 +585,7 @@ namespace sd { int 
bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weightsDepth->sizeAt(indWmC); // channels multiplier NDArray* outputDepth = output; @@ -553,11 +593,11 @@ namespace sd { outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? 
nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW delete outputDepth; } } @@ -1772,20 +1812,20 @@ namespace sd { - void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } - void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); + 
void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } - void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } - void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, 
sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } - void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, 
dW), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 76ba2e1df..47da861ed 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -217,10 +217,10 @@ void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, ND ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { +static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // bias [oC] // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -237,7 +237,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, 
dW, paddingMode); @@ -248,6 +248,14 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr else input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + std::vector wAxes; + if(0 == wFormat) + wAxes = {0, 1, 2}; + else if(1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); @@ -255,7 +263,7 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr //----- calculation of output -----// auto ctx = block.launchContext(); helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] //----- assign outTemp to output -----// if(isNCHW) { @@ -275,16 +283,16 @@ static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArr } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, 
const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC] always + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // bias [oC] = iC*mC // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) @@ -301,23 +309,30 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // 
channels multiplier std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput; + std::vector> modifOutput, modifWeights; std::vector outReShape; if(!isNCHW) { outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] } else { outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] } + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -325,7 +340,7 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, co NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] if(bias) // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); @@ -336,17 +351,17 @@ static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* 
input, co } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { +static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC] always - // weightsPoint [1, 1, iC*mC, oC] always + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // 
weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] // bias [oC], oC = iC*mC if weightsPoint=nullptr // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -363,7 +378,7 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weightsDepth->sizeAt(indWmC); // channels multiplier NDArray* outputDepth = output; @@ -371,18 +386,18 @@ static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDAr outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? 
nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW delete outputDepth; } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -1176,15 +1191,15 @@ void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& inp ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const 
NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // bias [oC] // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC] always + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] // gradB [oC] // kH filter(kernel) height @@ -1200,7 +1215,7 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); @@ -1214,13 +1229,27 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA gradOaxesForDot = {0, 2, 3}; // bS, oH, oW } + std::vector wPermut, colPermut; + if(0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + else if(1 
== wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } + else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); // ----- calculation of gradW ----- // if(gradW) { auto ctx = block.launchContext(); helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] + sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] } // ----- calculation of gradB ----- // @@ -1234,7 +1263,10 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA } //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] + // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] @@ -1245,20 +1277,20 @@ static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDA } 
////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC] always + // weights 
[kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // bias [oC] = [iC*mC] // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC] always + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] // gradB [oC] // kH filter(kernel) height @@ -1274,11 +1306,11 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2; + std::vector> modifGradO1, modifGradO2, modifWeights; std::vector gradOreShape; if(!isNCHW) { @@ -1294,6 +1326,13 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] } + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -1303,7 +1342,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con // ----- calculation of gradW and gradB ----- // 
helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] // ----- calculation of gradB ----- // if(gradB) { @@ -1316,7 +1355,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con } //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] if(!isNCHW) { @@ -1326,8 +1365,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); +void 
ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu index 015c08172..5cf93f10f 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -102,7 +102,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu index aeaaa6516..0d01dfef3 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -54,7 +54,7 @@ PLATFORM_IMPL(avgpool3dnew, 
ENGINE_CUDA) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); @@ -108,7 +108,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu index b58cc40f3..43dc7ce07 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu 
+++ b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu @@ -34,22 +34,25 @@ static void conv2dCUDNN(const LaunchContext* context, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int paddingMode, const bool isNCHW, const int wFormat) { + + // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); @@ -58,13 +61,13 @@ static void conv2dCUDNN(const LaunchContext* context, // weights descriptor cudnnFilterDescriptor_t w; cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW); if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); // output descriptor cudnnTensorDescriptor_t z; cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) + if(output->ews() == 1 && output->ordering() == 'c') err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); else err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); @@ -104,10 +107,10 @@ static void conv2dCUDNN(const LaunchContext* context, // add bias if it is present if (bias != nullptr) { - cudnnTensorDescriptor_t b; cudnnCreateTensorDescriptor(&b); - err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); + // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 
1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); @@ -131,22 +134,23 @@ static void conv2dBpCUDNN(const LaunchContext* context, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int paddingMode, const bool isNCHW, const int wFormat) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); @@ -155,7 +159,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, // gradO descriptor cudnnTensorDescriptor_t dz; cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) + if(gradO->ews() == 1 && gradO->ordering() == 'c') err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); else err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); @@ -164,7 +168,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, // gradI descriptor cudnnTensorDescriptor_t dx; cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1) + if(gradI->ews() == 1 && gradI->ordering() == 'c') err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); @@ -173,7 +177,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, // gradW descriptor cudnnFilterDescriptor_t dw; cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, oC, iC, kH, kW); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW); if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW 
failed", err); // description of convolution @@ -220,7 +224,8 @@ static void conv2dBpCUDNN(const LaunchContext* context, if(gradB != nullptr) { cudnnTensorDescriptor_t db; cudnnCreateTensorDescriptor(&db); - err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); + // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); @@ -251,7 +256,7 @@ static void conv2dBpCUDNN(const LaunchContext* context, PLATFORM_IMPL(conv2d, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -263,7 +268,8 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? 
INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width @@ -273,31 +279,35 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) { REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); } - NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - newWeights->assign(weights->permute({3,2,0,1})); // 
permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) + NDArray* newWeights = weights; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} + if(0 == wFormat) { + newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(isNCHW ? std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) + } NDArray* newInput = input; NDArray* newGradI = nullptr; if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); + conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat); if(newInput != input) delete newInput; - delete newWeights; + if(0 == wFormat) + delete newWeights; return Status::OK(); } @@ -322,12 +332,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CUDA) { PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height @@ -340,6 +350,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); @@ -347,7 +358,7 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); @@ -355,26 +366,30 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP 
CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} - NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kH, kW}, weights->dataType(), weights->getContext()); - - newWeights->assign(weights->permute({3,2,0,1})); // permute weights (kH, kW, iC, oC --> oC, iC, kH, kW) + NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} + if(0 == wFormat) { + newGradW = new NDArray(gradW->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), gradW->dataType(), gradW->getContext()); + newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(isNCHW ? 
std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) + } NDArray* newInput = input; NDArray* newGradI = gradI; if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); + conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat); - newGradW->permutei({2,3,1,0}); // [oC, iC, kH, kW] -> [kH, kW, iC, oC] - gradW->assign(newGradW); + if(0 == wFormat) { + newGradW->permutei(isNCHW ? std::vector({2,3,1,0}) : std::vector({1,2,3,0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC) + gradW->assign(newGradW); + } if(newInput != input) { @@ -387,8 +402,10 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { delete newGradI; } - delete newWeights; - delete newGradW; + if(0 == wFormat) { + delete newWeights; + delete newGradW; + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu index 1e86aaa07..9d226d6f7 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu @@ -34,13 +34,15 @@ static void conv3dCUDNN(const LaunchContext* context, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW) { + const int paddingMode, const bool isNCDHW, const int wFormat) { + + // cudnn support only one format for weights {oC,iC,kD,kH,kW} const int numDims = 5; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input 
depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); @@ -53,7 +55,7 @@ static void conv3dCUDNN(const LaunchContext* context, const std::vector xShape = {bS, iC, iD, iH, iW}; const std::vector zShape = {bS, oC, oD, oH, oW}; const std::vector wShape = {oC, iC, kD, kH, kW}; - const std::vector bShape = {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; + const std::vector bShape = {1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; const std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; @@ -120,7 +122,7 @@ static void conv3dCUDNN(const LaunchContext* context, cudnnTensorDescriptor_t b; cudnnCreateTensorDescriptor(&b); - err = cudnnSetTensorNdDescriptorEx(b, format, cudnnDataType(bias->dataType()), numDims, bShape.data()); + err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data()); if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); @@ -144,13 +146,15 @@ 
static void conv3dBpCUDNN(const LaunchContext* context, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW) { + const int paddingMode, const bool isNCDHW, const int wFormat) { + + // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW const int numDims = 5; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); @@ -170,6 +174,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, const std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? 
CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); // input descriptor cudnnTensorDescriptor_t x; @@ -201,7 +206,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, // gradW descriptor cudnnFilterDescriptor_t dw; cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); + err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data()); if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); // description of convolution @@ -280,7 +285,7 @@ static void conv3dBpCUDNN(const LaunchContext* context, PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) @@ -301,34 +306,39 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) + NDArray* newWeights = weights; // cudnn support only one format {oC,iC,kD,kH,kW} + if(1 != wFormat) { + newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), 
weights->getContext()); + newWeights->assign(weights->permute(0 == wFormat ? std::vector({4,3,0,1,2}) : std::vector({0,4,1,2,3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW + } NDArray* newInput = input; NDArray* newGradI = nullptr; if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW); + conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat); if(newInput != input) delete newInput; - delete newWeights; + if(1 != wFormat) + delete newWeights; return Status::OK(); } @@ -337,7 +347,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID @@ -353,12 +363,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? 
INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); @@ -379,10 +389,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); int trueoD, trueoH, trueoW; // true output depth/height/width ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); @@ -390,7 +401,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); std::vector expectedGradOShape = 
ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -398,20 +409,25 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - NDArray* newGradW = new NDArray(gradW->ordering(), {oC, iC, kD, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} - NDArray* newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); - - newWeights->assign(weights->permute({4,3,0,1,2})); // permute weights (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) + NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} + if(0 == wFormat) { + newGradW = new NDArray(gradW->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), gradW->dataType(), gradW->getContext()); + newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(isNCDHW ? 
std::vector({4,3,0,1,2}) : std::vector({4,0,1,2,3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC) + } NDArray* newInput = input; NDArray* newGradI = gradI; if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW); + conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat); + + if(0 == wFormat) { + newGradW->permutei(isNCDHW ? std::vector({2,3,4,1,0}) : std::vector({1,2,3,4,0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) + gradW->assign(newGradW); + } - newGradW->permutei({2,3,4,1,0}); // [oC, iC, kD, kH, kW] -> [kD, kH, kW, iC, oC] - gradW->assign(newGradW); if(newInput != input) { @@ -424,8 +440,10 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { delete newGradI; } - delete newWeights; - delete newGradW; + if(0 == wFormat) { + delete newWeights; + delete newGradW; + } return Status::OK(); } @@ -433,7 +451,7 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? 
INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu index 22b0f9b1c..28e845b00 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -124,7 +124,7 @@ void pooling2dCUDNN(const LaunchContext* context, int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); @@ -135,7 +135,7 @@ void pooling2dCUDNN(const LaunchContext* context, // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); @@ -144,7 +144,7 @@ void pooling2dCUDNN(const LaunchContext* context, // output descriptor cudnnTensorDescriptor_t z; cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) + if(output->ews() == 1 && output->ordering() == 'c') err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); else err = 
cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); @@ -187,7 +187,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); @@ -198,7 +198,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, // input and gradI descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); @@ -207,7 +207,7 @@ void pooling2dBpCUDNN(const LaunchContext* context, // gradO descriptor cudnnTensorDescriptor_t dz; cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) + if(gradO->ews() == 1 && gradO->ordering() == 'c') err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); else err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); @@ -255,7 +255,7 @@ void pooling3dCUDNN(const 
LaunchContext* context, int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); const int pSizes[] = {pD, pH, pW}; const int sSizes[] = {sD, sH, sW}; @@ -272,7 +272,7 @@ void pooling3dCUDNN(const LaunchContext* context, // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); else err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); @@ -281,7 +281,7 @@ void pooling3dCUDNN(const LaunchContext* context, // output descriptor cudnnTensorDescriptor_t z; cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) + if(output->ews() == 1 && output->ordering() == 'c') err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); else err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); @@ -330,7 +330,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, 
iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); const int pSizes[] = {pD, pH, pW}; const int sSizes[] = {sD, sH, sW}; @@ -347,7 +347,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, // input and gradI descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); else err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); @@ -356,7 +356,7 @@ void pooling3dBpCUDNN(const LaunchContext* context, // gradO descriptor cudnnTensorDescriptor_t dz; cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) + if(gradO->ews() == 1 && gradO->ordering() == 'c') err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); else err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu index ae07ce944..612206f35 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu @@ -39,14 +39,14 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute + // weights [iC, mC, kH, kW] // bias [oC], may be nullptr // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // oC = iC*mC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - 
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(1); auto handle = reinterpret_cast(context->getCuDnnHandle()); @@ -58,7 +58,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); @@ -73,7 +73,7 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, // output descriptor cudnnTensorDescriptor_t z; cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) + if(output->ews() == 1 && output->ordering() == 'c') err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); else err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); @@ -117,7 +117,8 @@ static void depthwiseConv2dCUDNN(const LaunchContext* context, cudnnTensorDescriptor_t b; cudnnCreateTensorDescriptor(&b); - err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); + // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 
1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); err = cudnnAddTensor(*handle, alpha, b, bias->getSpecialBuffer(), alpha, z, output->specialBuffer()); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); @@ -146,14 +147,14 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights, gradW [iC, mC, kH, kW], mkl doesn't support this format, so we'll make permute + // weights, gradW [iC, mC, kH, kW] // gradB [oC], may be nullptr // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // oC = iC*mC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(1); auto handle = reinterpret_cast(context->getCuDnnHandle()); @@ -165,7 +166,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, // input descriptor cudnnTensorDescriptor_t x; cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) + if(input->ews() == 1 && input->ordering() == 'c') err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), 
input->strideAt(indIiH + 1)); @@ -174,7 +175,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, // gradO descriptor cudnnTensorDescriptor_t dz; cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) + if(gradO->ews() == 1 && gradO->ordering() == 'c') err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); else err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); @@ -183,7 +184,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, // gradI descriptor cudnnTensorDescriptor_t dx; cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1) + if(gradI->ews() == 1 && gradI->ordering() == 'c') err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); else err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); @@ -241,7 +242,8 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, if(gradB != nullptr) { cudnnTensorDescriptor_t db; cudnnCreateTensorDescriptor(&db); - err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); + // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 
1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->getSpecialBuffer(), beta, db, gradB->getSpecialBuffer()); @@ -272,7 +274,7 @@ static void depthwiseConv2dBpCUDNN(const LaunchContext* context, PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) @@ -290,22 +292,31 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} - newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) + std::vector wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is 
{iC, mC, kH, kW} in our case + if(0 == wFormat) + wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW + else if(1 == wFormat) + wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW + else + wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW + + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(wPermut)); NDArray* newInput = input; NDArray* newGradI = nullptr; @@ -326,12 +337,13 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - const int mC = weights->sizeAt(3); + const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; @@ -344,12 +356,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? 
INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -366,10 +378,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier int trueoH, trueoW; // correct output height, width @@ -378,17 +391,30 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector 
expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + std::vector wPermut, gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} + if(0 == wFormat) { + wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW + gradWPermut = {2,3,0,1}; // iC, mC, kH, kW -> kH, kW, iC, mC + } + else if(1 == wFormat) { + wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW + gradWPermut = {1,0,2,3}; // iC, mC, kH, kW -> mC, iC, kH, kW + } + else { + wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW + gradWPermut = {1,2,3,0}; // iC, mC, kH, kW -> mC, kH, kW, iC + } - NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); // cudnn support format {oC, iC/groupCount, kH, kW} + NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute({2,3,0,1})); // assign permuted weights (kH, kW, iC, mC --> iC, mC, kH, kW) + 
newWeights->assign(weights->permute(wPermut)); NDArray* newInput = input; NDArray* newGradI = gradI; @@ -397,7 +423,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); - newGradW->permutei({2,3,0,1}); // [iC, mC, kH, kW] -> [kH, kW, iC, mC] + newGradW->permutei(gradWPermut); gradW->assign(newGradW); if(newInput != input) { @@ -420,14 +446,15 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - const int mC = weights->sizeAt(3); + const int mC = weights->sizeAt(0 == wFormat ? 
3 : 0); const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu index 841faa0d3..3919d9614 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu @@ -98,7 +98,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu index 82e7b9f84..d28541b08 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -54,7 +54,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - 
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); @@ -106,7 +106,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 9df7bedf3..4adab2dfe 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -60,7 +60,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input 
height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); if (paddingMode) ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -105,7 +105,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index e8582658e..96110bd29 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -61,7 +61,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; 
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); if(paddingMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); @@ -109,7 +109,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index cc52e90b3..173880e63 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -91,12 +91,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc 
x_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + mkldnnUtils::setBlockStrides(x, x_user_md); // z, output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(z, xRank, z_user_md); + mkldnnUtils::setBlockStrides(z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -112,9 +112,9 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); - - // z + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); auto z_mkl_mem = zReorder ? 
dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; @@ -207,19 +207,19 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); - + mkldnnUtils::setBlockStrides(x, x_user_md); + // dLdO dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(dLdO, xRank, dLdO_user_md); + mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); // dLdI dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(dLdI, xRank, dLdI_user_md); + mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -239,10 +239,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdO - mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, args, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // mean auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index a3ea56bb6..0aa05f7f2 100644 --- 
a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -38,13 +38,13 @@ namespace platforms { static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW) { + const int paddingMode, const int isNCHW, const int wFormat) { - // weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] + // mkl support weights in [oC, iC, kH, kW] format only int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -53,8 +53,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims dilation = { dH-1, dW-1}; - auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + auto xzFormatMkl = isNCHW ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kH, kW}; @@ -66,17 +66,29 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; + } + else { + i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] + } + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + 
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + } // bias dnnl::memory::desc b_mkl_md; @@ -85,9 +97,8 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - - mkldnnUtils::setBlockStrides(output, 4, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -103,10 +114,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -135,13 +146,13 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, NDArray *gradI, NDArray *gradW, NDArray *gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW) { + const int paddingMode, const int isNCHW, const int wFormat) { - // weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] + // mkl support weights/gradW in [oC, iC, kH, kW] format only int bS, iC, iH, iW, oC, 
oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -150,8 +161,8 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims dilation = { dH-1, dW-1}; - auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + auto xzFormatMkl = isNCHW ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kH, kW}; @@ -163,36 +174,60 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; + } + else { + i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] + } + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + 
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + } // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); - + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); - + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; + } + else { + i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] + } + gradW_user_md.data.format_desc.blocking.strides[0] = 
gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + } // gradB dnnl::memory::desc gradB_mkl_md; @@ -221,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); @@ -489,7 +524,7 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, PLATFORM_IMPL(conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -500,24 +535,25 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME bool isNCHW = block.getIArguments()->size() > 9 ? 
!INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } @@ -536,12 +572,12 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) { PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or 
[bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height @@ -554,10 +590,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); @@ -566,13 +603,13 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, 
kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 0e853865b..68f0eea89 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -40,13 +40,13 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW) { + const int paddingMode, const int isNCDHW, const int wFormat) { - // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] + // mkl support weights in [oC, iC, kD, kH, kW] format only int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -56,8 +56,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - auto xzFrmat = isNCDHW ? 
dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; @@ -69,18 +69,30 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(input, 5, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if(0 == wFormat) { + i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; + } + else { + i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 
= 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] + } + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + } // bias dnnl::memory::desc b_mkl_md; @@ -89,8 +101,8 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(output, 5, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -106,11 +118,11 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); - + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + // bias if(bias != nullptr) { auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); @@ -140,13 +152,13 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const 
int paddingMode, const int isNCDHW) { + const int paddingMode, const int isNCDHW, const int wFormat) { - // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] + // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -156,8 +168,8 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + auto xzFormatMkl = isNCDHW ? 
dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; @@ -169,40 +181,64 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - mkldnnUtils::setBlockStrides(input, 5, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if(0 == wFormat) { + i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; + } + else { + i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] + } + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + 
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + } // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if(0 
== wFormat) { + i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } + else if(1 == wFormat) { + i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; + } + else { + i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] + } + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); + } // gradB dnnl::memory::desc gradB_mkl_md; @@ -231,10 +267,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); @@ -486,7 +522,7 @@ static void conv3dBpMKLDNN(sd::graph::Context &block, PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 2 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) @@ -507,12 +543,13 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -520,7 +557,7 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { if (paddingMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 
paddingMode, isNCDHW); + conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); return Status::OK(); } @@ -538,12 +575,12 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); @@ -564,10 +601,11 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); if(paddingMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); @@ -576,26 +614,26 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); - std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, 
bias->rankOf(), bias->lengthOf()); - conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); + conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); return Status::OK(); } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 1ee177e6a..a1ca2a717 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -34,19 +34,30 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int paddingMode, const bool isNCHW, const int wFormat) { - // weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation + // mkl supports weights format [oC, iC, kH, kW] only int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims padding = { pH, pW }; dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dH-1, dW-1 }; + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> 
[oC, iC, kH, kW] + } + else { + i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + } + // input type dnnl::memory::data_type xType; if(input->dataType() == DataType::FLOAT32) @@ -76,8 +87,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N else zType = dnnl::memory::data_type::s32; - dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kH, kW}; @@ -87,17 +98,17 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + 
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); // bias dnnl::memory::desc b_mkl_md; @@ -106,8 +117,8 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - mkldnnUtils::setBlockStrides(output, 4, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -124,10 +135,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -156,19 +167,30 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N ////////////////////////////////////////////////////////////////////////// static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int 
paddingMode, const bool isNCHW, const int wFormat) { - // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation + // mkl supports weights/gradW in [oC, iC, kH, kW] format only int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims padding = { pH, pW }; dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dH-1, dW-1 }; + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + } + else { + i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + } + // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type @@ -182,8 +204,8 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + dnnl::memory::format_tag xFormatMkl = isNCHW ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kH, kW}; @@ -193,36 +215,36 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); + 
mkldnnUtils::setBlockStrides(gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); // gradB dnnl::memory::desc gradB_mkl_md; @@ -251,10 +273,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, 
op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); @@ -311,7 +333,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const PLATFORM_IMPL(deconv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) @@ -327,14 +349,15 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -344,7 +367,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } - deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } @@ -377,12 +400,12 @@ PLATFORM_CHECK(deconv2d, ENGINE_CPU) { PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto bias = block.width() > 3 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -398,18 +421,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); int trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -420,19 +444,19 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } - deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } 
PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] int dH = INT_ARG(6); // dilations height diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index e7283e1d3..3236990b1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -34,7 +34,7 @@ namespace platforms { static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const bool isNCHW) { + const bool isNCHW, const int wFormat) { // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] @@ -52,8 +52,8 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // gradI type dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? 
dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kH, kW}; @@ -66,7 +66,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); @@ -75,13 +75,13 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - mkldnnUtils::setBlockStrides(gradO, 4, gradO_user_md); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); + 
mkldnnUtils::setBlockStrides(gradI, gradI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -101,10 +101,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // provide memory buffers and check whether reorder is required // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); @@ -128,10 +128,10 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height int kW = INT_ARG(1) > 0 ? 
INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width @@ -143,6 +143,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); @@ -188,7 +189,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] // } - deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); + deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); // delete weights; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index dc50288a0..bcc3d700a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -35,19 +35,30 @@ namespace platforms { static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const bool isNCDHW) { + const bool isNCDHW, const int wFormat) { - // weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation + // mkl supports weights in [oC, iC, kD, kH, kW] only int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - 
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; + uint i0, i1, i2, i3, i4; + if(0 == wFormat) { + i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + } + else { + i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + } + // input type dnnl::memory::data_type xType; if(input->dataType() == DataType::FLOAT32) @@ -77,8 +88,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N else zType = dnnl::memory::data_type::s32; - dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + dnnl::memory::format_tag xFormatMkl = isNCDHW ? 
dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; @@ -88,18 +99,18 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - mkldnnUtils::setBlockStrides(input, 5, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); // bias dnnl::memory::desc b_mkl_md; @@ -108,8 +119,8 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = 
dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - mkldnnUtils::setBlockStrides(output, 5, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -126,10 +137,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -161,19 +172,30 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const bool isNCDHW) { + const bool isNCDHW, const int wFormat) { - // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation + // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + 
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims padding = { pD, pH, pW }; dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; + uint i0, i1, i2, i3, i4; + if(0 == wFormat) { + i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + } + else { + i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + } + // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type @@ -187,8 +209,8 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + dnnl::memory::format_tag xFormatMkl = isNCDHW ? 
dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; @@ -198,38 +220,38 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - mkldnnUtils::setBlockStrides(input, 5, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, 
xFormat); - mkldnnUtils::setBlockStrides(gradO, 5, gradO_user_md); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - mkldnnUtils::setBlockStrides(gradI, 5, gradI_user_md); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); // gradB dnnl::memory::desc gradB_mkl_md; @@ -259,10 +281,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // provide 
memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); @@ -319,7 +341,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, PLATFORM_IMPL(deconv3d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) @@ -341,12 +363,13 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); @@ -356,7 +379,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); } - deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); + deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); return Status::OK(); } @@ -390,12 +413,12 @@ PLATFORM_CHECK(deconv3d, ENGINE_CPU) { PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, 
kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); @@ -416,17 +439,18 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 ? 
INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); int trueoD, trueoH, trueoW; // true output height, width ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = {kD, kH, kW, oC, iC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) @@ -435,7 +459,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - 
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); + deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); return Status::OK(); } @@ -443,12 +467,12 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] int dD = INT_ARG(9); // dilations depth diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index 1d365ef3a..2ca16bb8e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -35,19 +35,19 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int paddingMode, const bool isNCHW, const int wFormat) { // mkl supports only following case: mC = 1, oC = iC // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute + // weights {iC, mC, 1, kH, kW} // bias [oC], may be nullptr // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc // oC = iC*mC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier const int pWSame = (paddingMode == 2 && dW > 1) ? 
((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -57,6 +57,17 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims dilation = { dH-1, dW-1}; + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } + else { + i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + // input type dnnl::memory::data_type xType; if(input->dataType() == DataType::FLOAT32) @@ -86,8 +97,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, else zType = dnnl::memory::data_type::s32; - dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; + dnnl::memory::format_tag xzFormatMkl = isNCHW ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; @@ -97,18 +108,18 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); - // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); // bias dnnl::memory::desc b_mkl_md; @@ -117,8 +128,8 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - 
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); - mkldnnUtils::setBlockStrides(output, 4, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -135,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -166,19 +177,19 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW) { + const int paddingMode, const bool isNCHW, const int wFormat) { // mkl supports only following case: mC = 1, oC = iC // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights, gradW [kH, kW, iC, mC], mkl doesn't support this format, so we'll make permute + // weights/gradW {iC, mC, 1, kH, kW} // gradB [oC], may be nullptr // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc 
// oC = iC*mC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d @@ -188,6 +199,17 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; dnnl::memory::dims dilation = { dH-1, dW-1}; + uint i0, i1, i2, i3; + if(0 == wFormat) { + i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } + else if(1 == wFormat) { + i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } + else { + i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type @@ -201,8 +223,8 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; + dnnl::memory::format_tag xzFormatMkl = isNCHW ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; @@ -212,38 +234,38 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); - mkldnnUtils::setBlockStrides(input, 4, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); - // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); - mkldnnUtils::setBlockStrides(gradO, 4, 
gradO_user_md); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); - mkldnnUtils::setBlockStrides(gradI, 4, gradI_user_md); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; + // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // permute - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); // permute + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); gradW_user_md.data.format_desc.blocking.strides[2] = 0; - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(0); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(1); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); // gradB dnnl::memory::desc gradB_mkl_md; @@ -272,10 +294,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_weights_bp_prim_desc.src_desc(), DNNL_ARG_SRC); + 
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, args, w_user_md, op_data_bp_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); @@ -332,7 +354,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) @@ -347,21 +369,22 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } @@ -394,12 +417,12 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { auto input = 
INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always + auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); @@ -416,10 +439,11 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 ? 
INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); mC = weights->sizeAt(indWmC); // channels multiplier int trueoH, trueoW; // correct output height, width @@ -428,13 +452,13 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, mC}; + std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - 
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } @@ -443,12 +467,12 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] auto gradB = block.width() > 3 ?
OUTPUT_VARIABLE(2) : nullptr; // [oC] const DataType xType = input->dataType(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index c4d987054..94c795401 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -272,14 +272,14 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, lstm_prim_desc.src_layer_desc(), DNNL_ARG_SRC_LAYER); - + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); + // wx - mkldnnUtils::loadDataToMklStream(Wx, engine, stream, args, wx_user_md, lstm_prim_desc.weights_layer_desc(), DNNL_ARG_WEIGHTS_LAYER); + mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); // wr - mkldnnUtils::loadDataToMklStream(Wr, engine, stream, args, wr_user_md, lstm_prim_desc.weights_iter_desc(), DNNL_ARG_WEIGHTS_ITER); - + mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); + // h auto h_user_mem = dnnl::memory(h_user_md, engine, h->getBuffer()); const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); @@ -288,17 +288,17 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // b if(b) { - mkldnnUtils::loadDataToMklStream(b, engine, stream, args, b_user_md, lstm_prim_desc.bias_desc(), DNNL_ARG_BIAS); + mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); } // hI if(hI) { - mkldnnUtils::loadDataToMklStream(hI, engine, stream, args, hI_user_md, lstm_prim_desc.src_iter_desc(), DNNL_ARG_SRC_ITER); + 
mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); } // cI if(cI) { - mkldnnUtils::loadDataToMklStream(cI, engine, stream, args, cI_user_md, lstm_prim_desc.src_iter_c_desc(), DNNL_ARG_SRC_ITER_C); + mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); } bool hLReorder(false), cLReorder(false); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 805507277..91e56d801 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -163,7 +163,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(xTR, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); /* auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); @@ -173,7 +173,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b args[DNNL_ARG_SRC] = x_mkl_mem; */ // y - mkldnnUtils::loadDataToMklStream(yTR, engine, stream, args, y_user_md, op_prim_desc.weights_desc(), DNNL_ARG_WEIGHTS); + mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); /* auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 1b60684a1..50b3fafa5 100644 --- 
a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); if (paddingMode) ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); @@ -102,7 +102,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index fbd17d882..078b45ba0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp 
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -60,7 +60,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); if(paddingMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); @@ -107,7 +107,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 1c6974ea8..b8e489c4c 100644 --- 
a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -30,7 +30,7 @@ namespace mkldnnUtils { ////////////////////////////////////////////////////////////////////// void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ - + std::vector vDims(rank); for (auto i = 0; i < rank; i++) { vDims[i] = array->sizeAt(i); @@ -56,25 +56,27 @@ dnnl::memory::format_tag getFormat(const int rank){ } return dnnl::memory::format_tag::a; // 1 == dataSetRank } + ////////////////////////////////////////////////////////////////////// -void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd){ - if (array->ews() != 1 || array->ordering() != 'c') { - mklMd.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < rank; ++i) { - mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); - } +void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){ + + if (array->ews() != 1 || array->ordering() != 'c') { + mklMd.data.format_kind = dnnl_blocked; // overrides format + for (auto i = 0; i < array->rankOf(); ++i) { + mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); } + } } //////////////////////////////////////////////////////////////////////////////////////////////// -void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, - std::unordered_map& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG ){ - - auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); - const bool bReorder = primitive_md != user_mem.get_desc(); - auto mkl_mem = bReorder ? 
dnnl::memory(primitive_md, engine) : user_mem; - if (bReorder) - dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); - args[DNNL_ARG] = mkl_mem; +void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, + dnnl::memory& arg) { + + auto user_mem = dnnl::memory(user_md, engine, array->getBuffer()); + const bool bReorder = primitive_md != user_mem.get_desc(); + auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; + if (bReorder) + dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); + arg = mkl_mem; } ////////////////////////////////////////////////////////////////////// @@ -95,7 +97,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, if(rank == 4) { // 2d - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); strides = { sH, sW }; kernel = { kH, kW }; @@ -108,7 +110,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, } else { // 3d - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); + ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); strides = { sD, sH, sW }; kernel = { kD, kH, kW }; @@ -162,7 +164,7 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, 
op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // output auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); @@ -199,7 +201,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, if(rank == 4) { // 2d - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); strides = { sH, sW }; kernel = { kH, kW }; @@ -212,7 +214,7 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, } else { // 3d - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); + ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); strides = { sD, sH, sW }; kernel = { kD, kH, kW }; @@ -280,8 +282,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, std::unordered_map args; // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, args, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); - + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); @@ -291,8 +293,8 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, if(mode == algorithm::pooling_max) { // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, args, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); - + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, 
op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + // z auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); args[DNNL_ARG_DST] = z_mkl_mem; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 29b5ebf2a..dd512a884 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -131,7 +131,7 @@ namespace sd { * @param reference to memory descriptor * @return memory format */ - void setBlockStrides(const NDArray* array, const int rank, dnnl::memory::desc& mklMd); + void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd); ////////////////////////////////////////////////////////////////////// /** * This function load and reorder user memory to mkl @@ -143,8 +143,8 @@ namespace sd { * @param primitive memory descriptor * @param dnnl arg activation enumerator */ - void loadDataToMklStream(const NDArray* array, dnnl::engine& engine, dnnl::stream& stream, - std::unordered_map& args, dnnl::memory::desc& user_md, dnnl::memory::desc primitive_md, int DNNL_ARG); + void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, + dnnl::memory& arg); /** * Utility methods for MKLDNN diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index d67d205da..a178e84c2 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -55,12 +55,12 @@ namespace sd { dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + mkldnnUtils::setBlockStrides(x, x_user_md); // z 
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); - mkldnnUtils::setBlockStrides(z, xRank, z_user_md); + mkldnnUtils::setBlockStrides(z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -80,7 +80,7 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); @@ -156,19 +156,19 @@ namespace sd { // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + mkldnnUtils::setBlockStrides(x, x_user_md); // dLdx dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); // todo if mkl does not support broadcast we can remove this format = mkldnnUtils::getFormat(dLdzRank); // dLdz dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -188,7 +188,7 @@ namespace sd { // provide memory buffers and check whether reorder is required for 
forward // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); // dLdx auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); @@ -200,7 +200,7 @@ namespace sd { argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); // run calculations forward dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index 5a3ab0f57..a82bc2706 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -44,12 +44,12 @@ namespace sd { dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + mkldnnUtils::setBlockStrides(x, x_user_md); // z dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(z, xRank, z_user_md); + mkldnnUtils::setBlockStrides(z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -68,7 +68,7 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, 
op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); @@ -132,17 +132,17 @@ namespace sd { dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, xRank, x_user_md); + mkldnnUtils::setBlockStrides(x, x_user_md); // dLdz dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, xRank, dLdz_user_md); - + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + // dLdx dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -162,10 +162,10 @@ namespace sd { // provide memory buffers and check whether reorder is required for forward // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC); + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, args, dLdz_user_md, op_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST); + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // dLdx auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer()); diff --git 
a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index f7e1ae7b9..149ab3c5f 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -178,7 +178,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) { ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -268,7 +268,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_7) { ASSERT_EQ(Status::OK(), results.status()); - + } ////////////////////////////////////////////////////////////////////// @@ -309,6 +309,72 @@ TEST_F(ConvolutionTests1, conv2d_8) { } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_9) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kH, kW}, {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, + 2.7, 3.9, -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, + 2.8, 4., -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, + 84.200005, 76.400009, -221.5, -226.899994, -237.699997, -243.099991, -241.899994, -248.5, -261.700012, -268.299988, + -266.799988, -274.600006, 
-290.200012, -298.}, sd::DataType::FLOAT32); + + input.linspace(25,-0.5); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_10) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kH, kW, iC}, {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, + 3.6, 3.9, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, + 3.1, 3.4, 3.7, 4., -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, + 2.9, 3.2, 3.5, 3.8, 4.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, 470.500031, 113.600006, 130.400009, 142.699982, + -63.999958, -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, -144.100021, -124., -108.399994, -128.799988, -98.799973, -73.300011, + -150.400009, -125.200012, -104.500008, -133.300003, -120.399994, -112.000008, -170.199997, -154., -142.299988, -146.200012, -133.199997, -124.699997, + -88.000008, -80.800003, -78.099991, -170.200012, -173.199997, -180.699982, -223., -229.199997, -239.900009, -88., -90.400002, -97.300003, -323.200012, + -336.399994, -354.100037, -344.800018, 
-362.799988, -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, -415.599976, -447.700012, -409.599976, + -442., -478.900024, -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, 68.899994, 141.799988, 116.399994, 86.5, 171.200012, 159.200012, 142.699997}, sd::DataType::FLOAT32); + + input.linspace(25,-0.5); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 
131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 
694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; @@ -542,7 +608,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { ASSERT_TRUE(expBGrad.equalsTo(gradB)); - + } @@ -587,7 +653,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { //epsilon->printBuffer("Result buffer"); ASSERT_TRUE(expEps.equalsTo(epsilon)); - + } TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { @@ -669,14 +735,13 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { ASSERT_TRUE(z2d->equalsTo(&exp2FF)); } -TEST_F(ConvolutionTests1, TestDeconv_bp_1) { +TEST_F(ConvolutionTests1, 
deconv2d_bp_1) { int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; int oH=4,oW=4; int paddingMode = 1; // 1-SAME, 0-VALID; int dataFormat = 0; // 1-NHWC, 0-NCHW - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); NDArray bias('c', {oC}, sd::DataType::FLOAT32); NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, sd::DataType::FLOAT32); @@ -707,7 +772,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_1) { auto gradI = results.at(0); auto gradW = results.at(1); auto gradB = results.at(2); - + ASSERT_TRUE(expGradI.isSameShape(gradI)); ASSERT_TRUE(expGradI.equalsTo(gradI)); @@ -719,47 +784,95 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_1) { } -TEST_F(ConvolutionTests1, TestDeconv_bp_2) { - /* - Input shape: - [3, 3, 14, 14] - Output shape: - [3, 2, 15, 15] - Weights shape: - [3, 2, 2, 2] - Bias shape: - [1, 2] - weight shape: - [3, 2, 2, 2] - weight grad shape: - [3, 2, 2, 2] - bias grad shape: - [2] - input epsilon shape: - [3, 2, 15, 15] - output epsilon shape: - [3, 3, 14, 14] - */ - /* - auto input('c', {3, 3, 14, 14}); - auto bias('c', {2}); - auto weights('c',{3, 2, 2, 2}); - auto epsilon('c', {3, 2, 15, 15}); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_2) { + int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=4; // 5,4 + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - input.linspace(1); - weights.linspace(1); - bias.linspace(1); - epsilon.linspace(1); + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, oC, kH, kW}, {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); - sd::ops::deconv2d_bp op; + NDArray expGradI('c', {bS, iC, 
iH, iW}, {-77.400002, -77.199997, -77., -76.800003, -76.599998, -76.400002, -76.200005, -76., -75.800003, -75.599998, -75.399994, + -75.199997, -11.32, -11.29, -11.26, -11.23, -100.839996, -100.580002, -100.32, -100.059998, -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, -98.50, + -98.240005, -97.979996, -26.52, -26.450001, -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, -123.32, -123., -122.68, -122.360001, -122.040001, + -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, -41.610001, -41.50, -41.389999, -71., -70.800003, -70.599998, -70.399994, -70.199997, -70., -69.800003, -69.600006, -69.400002, -69.199997, -69., -68.799995, -10.360001, -10.33, -10.30, -10.27, -92.519997, -92.260002, -92., -91.740005, -91.479996, -91.220001, -90.960007, -90.700005, -90.440002, -90.18, -89.919998, -89.660004, -24.280001, -24.209999, -24.139999, -24.07, -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, -110.839996, -110.520004, -38.200001, -38.09, -37.980003, -37.869999, -64.599998, -64.400002, -64.199997, -64., -63.799995, -63.599998, -63.400002, -63.199997, -63., -62.799995, -62.599998, -62.400002, -9.40, -9.37, -9.34, -9.309999, -84.200005, -83.940002, -83.68, -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, -21.970001, -21.90, -21.83, -103.800003, -103.480003, -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, -34.68, -34.57, -34.459999, -34.349998}, sd::DataType::FLOAT32); - auto result = op.evaluate({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); + NDArray expGradW('c', {iC, oC, kH, kW}, {-3010.799805, -2502.420410, -2899.439209, -2407.380615, -242.159332, -437.460510, -253.680466, -434.580048, 2526.479980, 1627.500000, 2392.079834, 
1538.220093}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, sd::DataType::FLOAT32); + input.linspace(70., -1); + gradO.linspace(-4, 0.01); - */ + sd::ops::deconv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_3) { + + int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=5,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, kH, kW, oC}, {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iH, iW, iC}, {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, -117.540001, -85.619995, -101.279999, -116.940002, -85.18, + -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, -114.539993, -83.419998, -98.68, + -113.939995, -82.979996, -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, -111.539993, + -81.220001, -96.080002, -110.939995, -80.779999, -95.559998, -110.340012, 
-80.340004, -95.040001, -109.740005, -79.900002, -94.519997, -109.139992, -77.699997, + -91.919998, -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, -75.940002, -89.839996, -103.740005, -75.5, -89.320007, -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, -88.279999, -101.940002, -74.18, -87.759995, -101.339996, -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, -85.68, -98.939995, -71.979996, -85.160004, -98.339996, -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, -81.00, -93.539993, -68.019997, -80.479996, -92.940002, -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, -88.740005, -64.5, -76.320007, -88.139999, -64.060005, -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, -63.18, -74.759995, -86.339996, -62.739998, -74.239998, -85.739998, -62.299999, -73.720001, -85.139999}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kH, kW, oC}, {-592.800110, -593.039917, -594.719116, -594.960266, -427.199890, -427.919617, -432.959900, -433.679993, -261.600281, -262.799591, -271.200317, -272.399536}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + 
ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { auto input = NDArrayFactory::create('c', {2, 2, 6}); auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); @@ -1257,8 +1370,6 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { ASSERT_TRUE(expGradB.isSameShape(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB)); - - } ////////////////////////////////////////////////////////////////////// @@ -1289,6 +1400,105 @@ TEST_F(ConvolutionTests1, conv2d_bp_4) { ASSERT_EQ(Status::OK(), status); } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_5) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kH, kW}, {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW},{0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, + 0.392, -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, 0.668, 1.006, 0.266, 
-0.196, -0.866, -0.742, -0.407, -1.123, -0.752, 0.382, 0.662, + 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, -0.052, -1.406, -1.426, -0.749, -2.221, + -1.508, 1.624, 2.732, 1.072, 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, + -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, iC, kH, kW},{-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, 13.02, 13.700001, 15.06, 15.74, 17.1, + 17.780001, 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, 1.66, 2.66, 4.660001, + 5.660001, 7.66, 8.66, 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, -17.380001, -16.059999, -13.420003, -12.099999, + -9.46, -8.139999, -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, 30.139999, 31.459999, 34.099998, + 35.419998, 38.060001, 39.380001}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + // weights.linspace(3.6, -0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_6) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kH, kW, iC}, {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iH, iW, iC}, {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, 2.259, -1.305, 1.962, -1.602, 4.545, + -3.555, 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, 3.273, -2.019, + 2.832, -2.460, 6.069, -5.163, 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, 9.225, -9.783, 11.547, -8.757, 9.855, + -10.448999, 9.711, -5.517, 8.441999, -6.786, 17.505001, -13.922999, 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, 11.574, -6.570, 10.062, -8.082, + 20.745001, -16.514999, 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, kH, kW, iC},{34.559998, 41.760010, 48.959999, 
56.160004, 33.119999, 37.739998, 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, + 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, 33.120003, 40.499996, + 47.879993, 55.260002, 32.399998, 37.139996, 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, 91.199997, 96.799995, 102.399994, 108.0, 115.199989, + 120.959999, 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, 41.400002, 46.260002, + 120.0, 129.0, 138.0, 147.0, 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { @@ -1335,7 +1545,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { ASSERT_TRUE(expGradW.isSameShape(gradW)); ASSERT_TRUE(expGradW.equalsTo(gradW)); - + } @@ -1383,9 +1593,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { ASSERT_TRUE(expGradW.isSameShape(gradW)); ASSERT_TRUE(expGradW.equalsTo(gradW)); - -} +} 
//////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { @@ -1441,140 +1650,50 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { ASSERT_TRUE(expGradB.isSameShape(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int oC=iC*mC; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, - 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. 
,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) { +TEST_F(ConvolutionTests1, conv3d_bp_test4) { - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kD, kH, kW}, {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, + -5.3, -6.5, 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, 6.9, 5.7, 4.5, + 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, 
6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, + -3.3, -4.5, -5.7, -6.9, 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, -7., 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); - NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, - 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32); + NDArray expGradI('c', {bS, iC, iD, iH, iW},{1.847, 3.577, 1.694, 3.460, 6.542, 3.010, 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, 5.408, 9.483999, 3.932, 1.894, + 2.978, 1.012, 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, + 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, -0.140, -1.108, -1.040, -1.936, + -5.816, -4.024, -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, 1.271, + 2.263, 0.956, 2.830, 5.102, 2.200, 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, -2.318, -5.770, -3.524, -1.303, + -3.173, -1.906, -3.740, -8.757999, -5.090, -2.581, 
-5.873, -3.328, 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, -9.536, -22.059999, -12.596, -6.334, -14.233999, -7.936, 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, sd::DataType::FLOAT32); - input = 2.; - weights.linspace(0.1, 0.1); + NDArray expGradW('c', {oC, iC, kD, kH, kW},{-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, + 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, 46.879997, 48.200005, 50.839996, 52.160004, + 70.639999, 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, 86.479996, 
87.800003, 94.399994, 95.719994, 98.360001, 99.680008, 118.160004, 119.479996, + 122.120003, 123.440010, 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, -64.279999, + -62.319996, -52.519993, -50.559994, -46.640003, -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, + 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, -14.520001, -11.920001, 1.079997, 3.679997, 8.879997, 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, 211.680008, 214.280014, 219.479996, 222.079987}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); gradO.linspace(0.01, 0.01); - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - -} - -////////////////////////////////////////////////////////////////////// 
-TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) { - - auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); - auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); - auto b = NDArrayFactory::create('c', {1, 16}); - auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); - - auto gradI = in.like(); - auto gradW = w.like(); - auto gradB = b.like(); - - nd4j:ops::depthwise_conv2d_bp op; - auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); - ASSERT_EQ(Status::OK(), status); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); - - - NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 
29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, - 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, 
-19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, - -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, 
-310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, - -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, 
-652.169983, -352.320007, -380.220001, - -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, - -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, 
-533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, - -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, - -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, - -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, - -143929.015625, -144532.000000, -145137.000000, 
-126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); ASSERT_EQ(Status::OK(), results.status()); @@ -1586,49 +1705,49 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) { ASSERT_TRUE(expGradB.isSameShape(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB)); - } ////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) { +TEST_F(ConvolutionTests1, conv3d_bp_test5) { - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=3,oW=3; int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, 
kD, kH, kW, iC] - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kD, kH, kW, iC}, {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., + 8.7, 8.4, 8.1, 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, 14.9, 14.6, 14.3, 14., + 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, 11.3, 11., 10.7, 10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., + 4.7, 4.4, 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, 11.2, 10.9, 10.6, + 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); + NDArray expGradI('c', {bS, iD, iH, iW, iC}, {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, + 30.388000, 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, 43.852001, 42.807999, + 41.764000, 40.719997, 87.596001, 85.400002, 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, + 66.210007, 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, 70.179001, 68.369995, 66.561005, 
64.751999, 137.927994, 134.147995, 130.367996, 126.587997, + 146.891998, 142.787994, 138.683990, 134.580017, 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, 155.316010, 150.563995, 173.783997, 168.707993, + 163.631989, 158.556000, 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 59.044003, 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, 156.846008, 152.417999, 147.989990, 143.561996, 
298.707001, 289.769989, 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 243.872009}, sd::DataType::FLOAT32); + NDArray expGradW('c', {oC, kD, kH, kW, iC}, {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, + 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, 1090.080078, 1115.520020, 1140.959961, + 1166.400024, 1183.679932, 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, 1576.260010, + 1029.780029, 1046.429932, 1063.080078, 1079.729980, 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, 748.560059, 759.119995, 769.679993, 389.880005, + 422.819946, 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, 1048.770020, 1065.599976, 1082.429932, 
1083.420044, 1099.709961, 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, 742.319946, 753.119995, 763.919983, 774.720032}, sd::DataType::FLOAT32); - NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, 
-0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, - -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, 
-205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, - -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, 
-228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, - -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, 
-724.109985, -725.640015, -727.169922, -728.700012, - -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, - -1064.130005, -1065.840088, -1067.550049, 
-1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, - -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, - -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, - -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, - -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, - -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, 
-110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32); + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); - NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); ASSERT_EQ(Status::OK(), results.status()); @@ -1640,46 +1759,6 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) { ASSERT_TRUE(expGradB.isSameShape(gradB)); ASSERT_TRUE(expGradB.equalsTo(gradB)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test6) { - - int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, - 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, - 0.326, 0.194, 0.099, 0.229, 
0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - } ////////////////////////////////////////////////////////////////////// @@ -1689,37 +1768,6 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { int paddingMode = 1; // 1-SAME, 0-VALID; int dataFormat = 1; // 1-NDHWC, 0-NCDHW - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 
64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); - input = 2.; - weights = 1.; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - -} - - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, @@ -1744,7 +1792,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { } ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -1771,7 +1819,7 @@ 
TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -1795,7 +1843,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -1824,7 +1872,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -1852,7 +1900,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -1879,11 +1927,11 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; int paddingMode = 0; // 
1-SAME, 0-VALID; @@ -1906,11 +1954,11 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); @@ -1924,7 +1972,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { ASSERT_TRUE(e.isSameShape(z)); } -TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { +TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); @@ -1969,6 +2017,121 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { delete shapes; } +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { + + int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { + + int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=13,oW=13; + int paddingMode = 0; // 
1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->isSameShape(&expected)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_test12) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kD, kH, kW}, {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, + -6.9, -5.7, -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, -8.7, -7.5, -6.3, + -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, + -4.4, -3.2, -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, 
-4.9, -3.7, -2.5, -1.3, -0.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-42520.597656, -42344.199219, -41991.402344, -41814.996094, -40932.992188, -40756.597656, -40403.800781, -40227.406250, + -41953.601562, -41779.601562, -41431.597656, -41257.601562, -40387.601562, -40213.597656, -39865.601562, -39691.597656, -41391.105469, -41219.492188, + -40876.300781, -40704.699219, -39846.707031, -39675.097656, -39331.898438, -39160.300781, -17119.001953, -16942.599609, -16589.798828, -16413.400391, + -15531.399414, -15355.000000, -15002.199219, -14825.800781, -16897.597656, -16723.597656, -16375.599609, -16201.599609, -15331.599609, -15157.600586, + -14809.601562, -14635.598633, -16680.703125, -16509.099609, -16165.900391, -15994.300781, -15136.300781, -14964.700195, -14621.500000, -14449.900391}, sd::DataType::FLOAT32); + + input.linspace(150,-0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_test13) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kD, kH, kW, iC}, {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, + -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 
2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, -6.9, -6.6, -6.3, + -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, + 2.4, 2.7, 3., 3.3, 3.6, 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, -3.2, + -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oD, oH, oW, oC}, {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, 4193.299805, 1317.000000, 1413.199829, 1504.899902, + 3498.999756, 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, 854.500000, + 645.800049, 729.200073, 808.099976, 80.799995, 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, -2855.799805, -2611.600098, -2371.900879, + -2124.399414, -2003.199951, -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, -2654.800049, -2443.899902, -2045.200073, -1938.399902, + -1836.100220, -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, + -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, -2280.200195, 
-2215.600098, -2155.500732, -2129.799561, -2071.600098, -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, -1094.800049, -1160.800049, -1231.300049, -608.799988, -674.799988, -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, 1682.699829}, sd::DataType::FLOAT32); + + input.linspace(75,-0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + 
ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { @@ -1999,51 +2162,6 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { } -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { - - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - - input = 2.; - weights = 1.; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test12) { - - int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=13,oW=13; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - - input = 2.; - weights = 1.; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isSameShape(&expected)); - -} - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, vol2col_test1) { 
@@ -2406,7 +2524,6 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - } ////////////////////////////////////////////////////////////////////// @@ -2437,7 +2554,6 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - } ////////////////////////////////////////////////////////////////////// @@ -2629,6 +2745,82 @@ TEST_F(ConvolutionTests1, deconv2d_test8) { } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test9) { + + int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kH, kW}, {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, + 15.000000, 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, 94.500000, 69.500000, + 44.500000, 19.500000, 89.500000, 64.500000, 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, + 74.000000, 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, 34.000000, 9.000000, + 79.000000, 54.000000, 29.000000, 4.000000, 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, + 13.500000, 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, 28.500000, 3.500000, 98.000000, 73.000000, 48.000000, 23.000000, 93.000000, 68.000000, + 43.000000, 18.000000, 
88.000000, 63.000000, 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, 25.500000, 0.500000}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oH, oW, oC}, {-30844.250000, -29266.750000, -27689.250000, -26111.750000, -24534.250000, -52823.500000, -49718.500000, -46613.500000, -43508.500000, -40403.500000, -51118.500000, + -48113.500000, -45108.500000, -42103.500000, -39098.500000, -21501.750000, -20024.250000, -18546.750000, -17069.250000, -15591.750000, -42981.000000, -39976.000000, -36971.000000, -33966.000000, -30961.000000, + -69482.000000, -63572.000000, -57662.000000, -51752.000000, -45842.000000, -67072.000000, -61362.000000, -55652.000000, -49942.000000, -44232.000000, -26046.000000, -23241.000000, -20436.000000, -17631.000000, + -14826.000000, -38616.000000, -35911.000000, -33206.000000, -30501.000000, -27796.000000, -62252.000000, -56942.000000, 
-51632.000000, -46322.000000, -41012.000000, -59842.000000, -54732.000000, -49622.000000, + -44512.000000, -39402.000000, -23181.000000, -20676.000000, -18171.000000, -15666.000000, -13161.000000, -12204.250000, -10926.750000, -9649.250000, -8371.750000, -7094.250000, -17543.500000, -15038.500000, + -12533.500000, -10028.500000, -7523.500000, -16838.500000, -14433.499023, -12028.500000, -9623.500000, -7218.500000, -5361.750000, -4184.250000, -3006.750000, -1829.250000, -651.750000, -22046.750000, -20919.250000, + -19791.750000, -18664.250000, -17536.750000, -37478.500000, -35273.500000, -33068.500000, -30863.500000, -28658.500000, -35773.500000, -33668.500000, -31563.500000, -29458.500000, -27353.500000, -14954.250000, + -13926.750000, -12899.250000, -11871.750000, -10844.250000, -29886.000000, -27781.000000, -25676.000000, -23571.000000, -21466.000000, -47792.000000, -43682.000000, -39572.000000, -35462.000000, -31352.000000, + -45382.000000, -41472.000000, -37562.000000, -33652.000000, -29742.000000, -17451.000000, -15546.000000, -13641.000000, -11736.000000, -9831.000000, -25521.000000, -23716.000000, -21911.000000, -20106.000000, -18301.000000, -40562.000000, -37052.000000, -33542.000000, -30032.000000, -26522.000000, -38152.000000, -34842.000000, -31532.000000, -28222.000000, -24912.000000, -14586.000000, -12981.000000, -11376.000000, -9771.000000, -8166.000000, -7906.750000, -7079.250000, -6251.750000, -5424.250000, -4596.750000, -11198.500000, -9593.500000, -7988.500000, -6383.500000, -4778.500000, -10493.500000, -8988.500000, -7483.500000, -5978.500000, -4473.500000, -3314.250000, -2586.750000, -1859.250000, -1131.750000, -404.250000}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + 
ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test10) { + + int bS=2, oH=4,oW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int iH=4,iW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kH, kW, oC}, {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., + -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., 59., 54., 49., 44., + 39., 34., 29., 24., 19., 14., 9., 4., -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., -61., -66., -71., -76., -81., -86., -91., -96., + 98., 93., 88., 83., 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., -42., -47., + -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., 37., 32., 27., 22., 17., 12., 7., 2., + -3., -8., -13., -18., -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., + 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., -64., -69., -74., -79., -84., -89., -94., -99.}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oC, oH, oW}, {-14128., -21007., -20934., -20861., -13660., -12972., -12926.000977, -12880., -13468., -12788., -12742., -12696.000977, + -13276., -12604., -12558., -12512., -13408., -19569.5, -19501.5, -19433.5, -12230., -10117., -10081.000977, -10045., -12058., -9973., -9937., -9901.000977, + -11886., -9829., -9793., 
-9757., -12688., -18132., -18069., -18006., -10800., -7262., -7236., -7210., -10648., -7157.999512, -7132., -7106., -10496., -7054., + -7027.999512, -7002., -11968., -16694.5, -16636.5, -16578.5, -9370., -4406.999023, -4391., -4375., -9238., -4343., -4326.999023, -4311., -9106., -4279., -4263., + -4246.999023, -11247.999023, -15257., -15204., -15151., -7940., -1551.999023, -1546., -1540., -7828., -1528.000977, -1521.999023, -1516., -7716., -1504., + -1498.000977, -1491.999023, -10527.999023, -13819.5, -13771.5, -13723.5, -6510., 1303.000977, 1299., 1295., -6418., 1286.999023, 1283.000977, 1279., -6326., + 1271., 1266.999023, 1263.000977, -9807.999023, -12382., -12339., -12296., -5080., 4158.000977, 4144., 4130., -5008., 4101.999023, 4088., 4074., -4936., 4046., 4031.999023, 4018., -9088., -10944.5, -10906.5, -10868.5, -3650., 7013., 6989., 6965., -3598., 6917., 6893., 6869., -3546., 6821., 6797., 6773., -8368., -9507., -9474., -9441., -2220., 9868., 9834., 9800., -2187.999512, 9732., 9698., 9664., -2156., 9596., 9562., 9528., -7648., -8069.5, -8041.5, -8013.499512, -790.000488, 12723., 12679., 12635., -777.999512, 12547., 12503., 12459., -766., 12371., 12327., 12283., -10208., -15167., -15094., -15021., -9820., -9292., -9246., -9200., -9628., -9108., -9062., -9016., -9436., -8924., -8878., -8832., -9687.999023, -14129.5, -14061.5, -13993.5, -8790., -7236.999023, -7201., -7164.999512, -8618., -7093., -7057., -7021., -8446., -6949., -6913., -6877., -9168., -13092., -13029., -12966., -7760., -5182., -5156., -5129.999512, -7608., -5078., -5052., -5026., -7456., -4974., -4948., -4922., -8648., -12054.5, -11996.5, -11938.5, -6730., -3127., -3111., -3095., -6598., -3063., -3047., -3031., -6465.999512, -2999., -2983.000488, -2967., -8128., -11017., -10964., -10911., -5700.000488, -1072., -1066., -1060., -5587.999512, -1048.000488, -1042., -1036., -5476., -1023.999512, -1018.000488, -1012., -7608., -9979.5, -9931.5, -9883.5, -4670.000488, 983., 979., 975., -4577.999512, 
966.999512, 963., 959., -4486., 951.000488, 946.999512, 943., -7088., -8942., -8899., -8856., -3640.000488, 3038., 3024., 3010., -3567.999512, 2981.999512, 2968., 2954., -3496., 2926.000488, 2911.999512, 2898., -6568., -7904.5, -7866.5, -7828.499512, -2610.000488, 5093., 5069., 5045., -2557.999512, 4996.999512, 4973., 4949., -2506., 4901.000488, 4877., 4853., -6048., -6867., -6834., -6800.999512, -1580., 7148., 7114., 7080., -1547.999512, 7012., 6978., 6944., -1516., 6876.000488, 6842., 6808., -5528., -5829.5, -5801.5, -5773.499512, -550., 9203., 9159., 9115., -537.999512, 9027., 8983., 8939., -526., 8851., 8807., 8763.}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 2c3351175..169c51124 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -595,6 +595,74 @@ TEST_F(ConvolutionTests2, deconv3d_test5) { } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test6) { + + int bS=2, oD=4,oH=4,oW=4, oC=5,iC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int iD=3,iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kD, 
kH, kW}, {20., 15., 10., 5., 0., -5., -10., -15., 19., 14., 9., 4., -1., -6., -11., -16., 18., 13., 8., 3., -2., -7., -12., -17., + 17., 12., 7., 2., -3., -8., -13., -18., 16., 11., 6., 1., -4., -9., -14., -19., 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, + -11.1, -16.1, 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, + 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, + -17.200001, 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, 
-14.7, -19.700001, 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oD, oH, oW, oC}, {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, -8859.700195, -8338.700195, -7817.700195, + -7296.700195, -6775.700195, -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, -3572.850098, -3327.349854, -3081.850098, -2836.350098, + -2590.850098, -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, + -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, -6268.200195, + -5827.200195, -5386.200195, -4945.200195, -4504.200195, -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, -9558.400391, -8736.400391, + -7914.400391, -7092.399902, -6270.400391, -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, -926.200195, -525.199951, -124.199951, 
276.799927, 677.799805, -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, 439.300018, 580.299988, 721.299927, 
862.300049, 1003.300049, 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, 3665.599854, 3407.600098, 3149.599854, 2891.599854, 
2633.599854, 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, -1094.699951, -1473.699951, 
-1852.700073, -2231.699707, -2610.699951, -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, sd::DataType::FLOAT32); + + input.linspace(-27, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test7) { + + int bS=2, oD=4,oH=4,oW=4, iC=5,oC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=0,pW=0, dD=1,dH=1,dW=1; + int iD=4,iH=4,iW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kD, kH, kW, oC}, {20., 19.5, 19., 18.5, 18., 17.5, 17., 16.5, 
16., 15.5, 15., 14.5, 14., 13.5, 13., 12.5, 12., 11.5, 11., 10.5, 10., + 9.5, 9., 8.5, 8., 7.5, 7., 6.5, 6., 5.5, 5., 4.5, 4., 3.5, 3., 2.5, 2., 1.5, 1., 0.5, 0., -0.5, -1., -1.5, -2., -2.5, -3., -3.5, -4., -4.5, -5., -5.5, -6., + -6.5, -7., -7.5, -8., -8.5, -9., -9.5, -10., -10.5, -11., -11.5, -12., -12.5, -13., -13.5, -14., -14.5, -15., -15.5, -16., -16.5, -17., -17.5, -18., -18.5, + -19., -19.5, 19.9, 19.4, 18.9, 18.4, 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, 8.9, 8.4, 7.9, + 7.4, 6.9, 6.4, 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, -19.299999, 
-19.799999, 19.6, 19.1, 18.6, 18.1, 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, -18.4, -18.9, -19.4, -19.9}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, -4664.800293, -4640.199707, -4615.600098, + -2755.599854, -4566.400391, -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., + -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, -1273.200195, -1263.999878, + -1579.200073, -2308.599854, -2294., -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, -1400.800049, -1172., -1162.800049, -1153.600098, + -1362.399902, -1135.199951, -1126., -1116.799805, -1422.400024, -2075., -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, -1043.200195, + -1247.199951, -1024.800049, -1015.599976, -1006.400146, -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, -957., -949.800049, -1342., -935.399902, -928.199951, -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, -849., -841.800171, -834.599976, 
-1204.400024, -820.199707, -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, -705., -697.800171, -690.599976, -1811.199951, -3133., -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, -2514.399902, -4105., -4082.400146, -4059.800293, -1552., -2175.200195, -2162.599854, -2150., -1228.800049, -630., -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, -598.800049, -1167.999878, -588.400391, -583.199951, -578., -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, -520.800049, -515.599976, -1046.400146, -505.199829, -500., -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, -260.599854, -1020.400146, -254.199829, -251., -247.799927, -994., -241.400269, -238.200073, -234.999878, -1327.200073, -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, -148.599976, -145.400024, -782.799927, -139., -135.799805, -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, -2434.800049, -3906.799805, 
-3886.199951, -3865.599609, -2383.600098, -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, -842.39978, 100.799927, 102.000244, 103.200439, -820., 105.599609, 106.800171, 108., -1243.199951, -1638.599854, -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, 117.60022, -752.799805, 120., 121.200073, 122.400024, -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, 141.599731, -640.799988, 144., 145.200195, 146.400146, -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, -590.799927, 443., 442.200073, 441.399658, -572.39978, 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, -1184., -1441.199951, -1432.599854, -1424., -500.799927, 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, 808.000244, 805.200073, -472., 799.60022, 796.799683, 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, -443.199951, 777.200073, 
774.400024, 771.599854, -428.799927, 766., 763.200317, 760.400024, -414.400146, 754.800049, 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, 1262.400024, -1475.200073, -2462.75, 
-2449.949951, -2437.149902, -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, 3453.200195, 1400., 2129.800049, 2144.400146, 2159., 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, 1624.800171, 1634., 1643.199951, 1556., 1661.600098, 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, 1347., 1354.199951, 1410., 1368.600098, 1375.800171, 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, 1478.799805, 
1426.200073, 1433.400146, 1440.599609, 1513.199951, 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, 1901.599976, 3127., 3149.599854, 3172.200195, 1264., 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, 763.400024, 1091.599976, 769.800171, 773., 776.199951, 1118., 782.599976, 785.800049, 789., 1328.800049, 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, 811.400024, 814.60022, 1197.199951, 821., 824.199951, 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, 875.400146, 878.599854, 1329.199951, 885., 888.199951, 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, 1661.199951, 2685.200195, 2705.800049, 2726.399902, 
1712.399902, 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, 1605., 927.200012, 480., 481.199951, 482.400146, 949.599976, 484.800171, 486., 487.200073, 971.999878, 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, 501.60022, 1039.199951, 504., 505.199951, 506.400146, 1061.599976, 508.799927, 510., 511.200195, 1377.599976, 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, 525.600098, 1151.199829, 528., 529.199829, 530.400146, 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, 199.800171, 199., 198.200195, 826., 196.599731, 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, 881.199951, 187., 186.199829, 185.400024, 899.60022, 183.800171, 183., 182.200073, 1293.599976, 1754.499878, 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, 973.200073, 171., 170.200073, 169.400146, 1068.800049, 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, -74., -76.799805, -79.599854, 665.600098, -85.199829, -87.999756, -90.799805, 680., -96.400024, -99.199829, -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, -130., -132.800171, -135.599976, 737.599976, -141.200073, -144., -146.799805, 1209.599976, 1586., 1594.600098, 1603.200073, 
766.400146, -163.599976, -166.39978, -169.200073, 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, -379.799805, 534., -389.400146, -394.19989, -398.999817, 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, 1044.400146, 1051., 375.199951, -627.999756, -634.800171, -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, 912.500061, 
237.200012, -904.999817, -913.799866, -922.599792, 239.599976, -940.199707, -948.999817, -957.800171, 242., -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test1) { @@ -738,6 +806,96 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) { ASSERT_TRUE(expGradW.equalsTo(gradW)); } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test5) { + + int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, oC, kD, kH, kW}, {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW},sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iD, iH, iW}, {9.696001, 
9.684001, 9.672001, 9.66, 9.648001, 9.636, 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, 9.552, + 9.540001, 9.528, 9.516, 9.504001, 9.492, 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, 9.335999, + 9.324001, 9.312, 9.300001, 9.288001, 9.276001, 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, 13.152, 13.134001, + 13.116, 13.098, 13.080001, 13.062, 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, 3.608, 3.604, + 3.6, 3.596, 3.592, 3.588, 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, 10.848, 10.830001, 10.812, 10.794001, 
10.776, 10.758, 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kD, kH, kW}, {-73678.695312, -59907.972656, -67739.515625, -54962.082031, -15966.075195, -17115.042969, -15269.777344, -16101.275391, 41746.566406, 25677.917969, 37200.003906, 22759.517578}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-16, 0.02); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + 
ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test6) { + + int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=5,oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, kD, kH, kW, oC}, {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iD, iH, iW, iC}, {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, + 0.462, -0.072, 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, 0.912, 0.434, + -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, + 0.816, 0.402, -0.012, 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, 0.732, 0.374, + 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, 0.444, 0.278, 0.112, 0.432, 
0.274, 0.116, 0.42, 0.27, 0.12, 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kD, kH, kW, oC}, {-6328.958984, -6322.880371, -6134.400879, -6128.319824, -6318.079590, -6312.640137, -6144.000000, -6138.560547, -6307.202637, -6302.399414, -6153.599609, -6148.799316}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + 
gradO.linspace(-1.6, 0.01); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_1) { @@ -2230,7 +2388,6 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_6) { ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); - } ////////////////////////////////////////////////////////////////////// @@ -2263,7 +2420,6 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_7) { ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); - } ////////////////////////////////////////////////////////////////////// @@ -2285,8 +2441,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_8) { 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 
209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, - 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, - 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 
438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 
659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 
819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 
70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); input.linspace(-10, 0.1); @@ -2341,4 +2496,350 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_9) { } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_10) { + + int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, {0.130845, 0.569885, 0.644284, 0.198968}, sd::DataType::FLOAT32); + NDArray biases('c', {iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, + 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, + 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 
0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_11) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, kH, kW, iC}, {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., + 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 
4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, + 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, + 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, + 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, + 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, + 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 
291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, + 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 
495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 
639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 
1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int oC=iC*mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + 
auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, + 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, + 
0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test3) { + + auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); + auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); + auto b = NDArrayFactory::create('c', {1, 16}); + auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); + + auto gradI = in.like(); + auto gradW = w.like(); + auto gradB = b.like(); + + nd4j:ops::depthwise_conv2d_bp op; + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', 
{bS, oH, oW, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, + 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 
6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, + -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, 
-86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, + -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, 
-392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001, + -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, 
-739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, + -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, 
-1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, + -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, + -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, + -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, + -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results.at(0); + NDArray* gradW = results.at(1); + NDArray* gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + 
ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 
15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, + -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, 
-189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, + -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, 
-441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, + -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, 
-357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012, + -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, 
-1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, + -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, + -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, + -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, 
-342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results.at(0); + NDArray* gradW = results.at(1); + NDArray* gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) { + + int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, + 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, + 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); + + auto expGradW = NDArrayFactory::create('c', {kH, 
kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) { + + int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {3,4}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + + NDArray expGradI('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, + 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, + 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {mC, iC, kH, kW}, {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, sd::DataType::FLOAT32); + + input = 2.; + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, 
&weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + #endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file From b23ebee43293d54a54d7666bb23265fb3083730c Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 20 Mar 2020 12:42:29 +0300 Subject: [PATCH 07/17] - MKL-DNN version upgrade - deviceMutex replaced for CPU Signed-off-by: raver119 --- libnd4j/CMakeLists.txt.mkldnn.in | 2 +- libnd4j/include/execution/cpu/LaunchContext.cpp | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in index 4e4a130e1..36c426053 100644 --- a/libnd4j/CMakeLists.txt.mkldnn.in +++ b/libnd4j/CMakeLists.txt.mkldnn.in @@ -5,7 +5,7 @@ project(mkldnn-download NONE) include(ExternalProject) ExternalProject_Add(mkldnn GIT_REPOSITORY https://github.com/intel/mkl-dnn.git - GIT_TAG v1.2.1 + GIT_TAG v1.2.2 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" CONFIGURE_COMMAND "" diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 23df9f9f1..23e78c350 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -44,6 +44,7 @@ namespace sd { std::vector> LaunchContext::_contexts = std::vector>(); MAP_IMPL LaunchContext::_deviceMutexes; + std::mutex LaunchContext::_mutex; //////////////////////////////////////////////////////////////////////// LaunchContext::LaunchContext() { @@ -51,8 +52,6 @@ namespace sd { _workspace = nullptr; _deviceID = 0; - _deviceMutexes[_deviceID] = new std::mutex(); - #ifdef HAVE_MKLDNN _engine 
= new dnnl::engine(dnnl::engine::kind::cpu, 0); #endif @@ -73,8 +72,7 @@ namespace sd { } std::mutex* LaunchContext::deviceMutex() { - auto deviceId = AffinityManager::currentDeviceId(); - return _deviceMutexes[deviceId]; + return &_mutex; } void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { From f79207033b06322ab35bd6000e43a00d472fb9e8 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 20 Mar 2020 21:24:39 +1100 Subject: [PATCH 08/17] SameDiff multi-threaded inference (#263) * #8682 Don't log openmp BLAS threads for CUDA Signed-off-by: Alex Black * #8654 Add SameDiff multi-threaded tests Signed-off-by: Alex Black * Switching to op context for SameDiff exec Signed-off-by: Alex Black * Next steps Signed-off-by: Alex Black * Most back to passing Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * Better tests, test refactoring Signed-off-by: Alex Black * Small tweak Signed-off-by: Alex Black * Code duplication reduction Signed-off-by: Alex Black * More code deduplication Signed-off-by: Alex Black * CUDA fixes Signed-off-by: Alex Black * More CUDA fixes Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * Small fix Signed-off-by: Alex Black * ND4S small fixes Signed-off-by: Alex Black --- .../functions/DifferentialFunction.java | 5 + .../nd4j/autodiff/listeners/BaseListener.java | 5 +- .../org/nd4j/autodiff/listeners/Listener.java | 7 +- .../debugging/ArraySavingListener.java | 3 +- .../debugging/ExecDebuggingListener.java | 3 +- .../debugging/OpBenchmarkListener.java | 5 +- .../autodiff/listeners/impl/UIListener.java | 3 +- .../listeners/profiler/ProfilingListener.java | 5 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 20 - .../samediff/internal/InferenceSession.java | 78 +- .../samediff/internal/TrainingSession.java | 7 +- .../ActivationGradientCheckListener.java | 4 +- .../NonInplaceValidationListener.java | 21 +- .../linalg/api/ops/BaseIndexAccumulation.java | 6 + .../nd4j/linalg/api/ops/BaseOpContext.java | 39 + 
.../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 24 +- .../linalg/api/ops/BaseReduceFloatOp.java | 34 +- .../nd4j/linalg/api/ops/BaseReduceLongOp.java | 24 +- .../nd4j/linalg/api/ops/BaseReduceSameOp.java | 30 +- .../nd4j/linalg/api/ops/BaseScalarBoolOp.java | 6 + .../org/nd4j/linalg/api/ops/BaseScalarOp.java | 7 + .../linalg/api/ops/BaseTransformAnyOp.java | 7 +- .../linalg/api/ops/BaseTransformBoolOp.java | 24 +- .../linalg/api/ops/BaseTransformFloatOp.java | 26 +- .../linalg/api/ops/BaseTransformSameOp.java | 28 +- .../linalg/api/ops/BaseTransformStrictOp.java | 25 +- .../org/nd4j/linalg/api/ops/CustomOp.java | 8 +- .../nd4j/linalg/api/ops/DynamicCustomOp.java | 26 +- .../java/org/nd4j/linalg/api/ops/NoOp.java | 8 + .../org/nd4j/linalg/api/ops/OpContext.java | 14 +- .../org/nd4j/linalg/api/ops/ReduceOp.java | 4 +- .../org/nd4j/linalg/api/ops/TransformOp.java | 4 +- .../linalg/api/ops/custom/ScatterUpdate.java | 6 + .../ops/executioner/DefaultOpExecutioner.java | 78 +- .../api/ops/executioner/OpExecutioner.java | 9 + .../impl/layers/ExternalErrorsFunction.java | 6 + .../api/ops/impl/summarystats/Variance.java | 40 +- .../api/ops/impl/transforms/MaxOut.java | 18 +- .../linalg/api/ops/random/BaseRandomOp.java | 10 + .../ops/random/impl/BinomialDistribution.java | 5 + .../ops/random/impl/GaussianDistribution.java | 5 + .../random/impl/LogNormalDistribution.java | 5 + .../impl/TruncatedNormalDistribution.java | 5 + .../java/org/nd4j/linalg/factory/Nd4j.java | 4 + .../java/org/nd4j/nativeblas/Nd4jBlas.java | 6 +- .../nd4j/linalg/jcublas/blas/CudaBlas.java | 5 + .../ops/executioner/CudaExecutioner.java | 522 ++++++++------ .../ops/executioner/CudaGridExecutioner.java | 20 +- .../nativecpu/ops/NativeOpExecutioner.java | 665 ++++++++++-------- .../opvalidation/RandomOpValidation.java | 1 + .../samediff/SameDiffMultiThreadTests.java | 169 +++++ .../nd4j/autodiff/samediff/SameDiffTests.java | 4 + .../samediff/listeners/ListenerTest.java | 5 +- 
.../listener/OpExecOrderListener.java | 4 +- .../imports/listeners/ExecPrintListener.java | 3 +- .../listeners/ImportDebugListener.java | 3 +- .../org/nd4j/linalg/NDArrayTestsFortran.java | 3 +- .../nd4s/ops/FunctionalOpExecutioner.scala | 6 + 58 files changed, 1426 insertions(+), 691 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 655e4159f..94bda0b78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -31,6 +31,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.shade.jackson.annotation.JsonIgnore; @@ -708,6 +709,10 @@ public abstract class DifferentialFunction { throw new ND4JIllegalStateException("calculateOutputShape() method leaked out for [" + this.opName() + "]"); } + public List calculateOutputShape(OpContext oc){ + throw new ND4JIllegalStateException("calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]"); + } + /** * Calculate the data types for the output arrays. 
* Though datatypes can also be inferred from {@link #calculateOutputShape()}, this method differs in that it does not diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java index 61a5e75a3..6978a79d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java @@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; /** @@ -60,12 +61,12 @@ public abstract class BaseListener implements Listener { } @Override - public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) { //No op } @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { //No op } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java index 18e3b934b..4ed7df6c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java @@ -5,6 +5,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import 
org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; /** @@ -104,7 +105,7 @@ public interface Listener { * @param at Current iteration/epoch etc * @param op Operation that has just been executed */ - void preOpExecution(SameDiff sd, At at, SameDiffOp op); + void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext); /** * Called at the end of each operation execution
@@ -117,7 +118,7 @@ public interface Listener { * @param op Operation that has just been executed * @param outputs The output arrays for the just-executed operation */ - void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs); + void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs); /** * Called when any activation becomes available. @@ -127,7 +128,7 @@ public interface Listener { * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}
* It is guaranteed to be called for variables from requiredVariables().
*
- * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} - + * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, OpContext, INDArray[])} - * both contain the same information/arrays * * @param sd The SameDiff instance diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java index 6b64c69d8..9770dc50c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -9,6 +9,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -44,7 +45,7 @@ public class ArraySavingListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { List outNames = op.getOutputsOfOp(); for(int i=0; i this.minRuntime) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 6c38c6c9c..b4f3a371b 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -19,6 +19,7 @@ import org.nd4j.graph.UIInfoType; import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.primitives.Pair; @@ -410,7 +411,7 @@ public class UIListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { //Do training set evaluation, if required diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java index 9b92b0412..3dc21876e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java @@ -30,6 +30,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -192,7 +193,7 @@ public class ProfilingListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + public void 
preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext opContext) { if (logActive) { opStartNano = System.nanoTime(); @@ -202,7 +203,7 @@ public class ProfilingListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { if (logActive) { long now = System.nanoTime(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index de421b297..7ca809b2d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -105,7 +105,6 @@ import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)} */ @AllArgsConstructor -@Builder @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; @@ -1232,25 +1231,6 @@ public class SameDiff extends SDBaseOps { return result; } - - /** - * Create a new SameDiff instance from an existing instance. 
- * Note that state (variables and functions) is shared between the two SameDiff instance - * - * @param originalSameDiff Original SameDiff instance - * @return Copy - */ - public static SameDiff create(SameDiff originalSameDiff) { - SameDiff ret = SameDiff.builder() - .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances) - .build(); - ret.variables.putAll(originalSameDiff.variables); - //ensuring proper sameDiff reference - DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret); - ret.functionFactory = differentialFunctionFactory; - return ret; - } - @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 9b8d751eb..26bf82893 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -18,6 +18,7 @@ package org.nd4j.autodiff.samediff.internal; import lombok.*; import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.Pointer; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Listener; @@ -46,6 +47,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import java.util.*; @@ -65,7 +67,7 @@ import java.util.*; * @author Alex Black */ @Slf4j -public class InferenceSession extends AbstractSession { +public class InferenceSession extends AbstractSession> { 
private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" + "Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed."; @@ -83,6 +85,8 @@ public class InferenceSession extends AbstractSession { private IdentityDependencyTracker arrayUseTracker = new IdentityDependencyTracker<>(); + private Map opContexts = new HashMap<>(); + public InferenceSession(@NonNull SameDiff sameDiff) { super(sameDiff); mmgr = new ArrayCacheMemoryMgr(); @@ -204,18 +208,19 @@ public class InferenceSession extends AbstractSession { } @Override - public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] getOutputs(Pair opPair, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { + SameDiffOp op = opPair.getFirst(); at.setFrameIter(outputFrameIter); if (listeners != null && listeners.size() > 0) { SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName()); for (Listener l : listeners) { if (l.isActive(at.operation())) - l.preOpExecution(sameDiff, at, sdOp); + l.preOpExecution(sameDiff, at, sdOp, opPair.getSecond()); } } - INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); + INDArray[] out = doExec(op.getOp(), opPair.getRight(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); if (log.isTraceEnabled()) { StringBuilder sb = new StringBuilder(); @@ -246,7 +251,7 @@ public class InferenceSession extends AbstractSession { } - l.opExecution(sameDiff, at, batch, op, out); + l.opExecution(sameDiff, at, batch, op, opPair.getSecond(), out); for (String varName : namedOuts.keySet()) { l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName)); @@ -255,6 +260,8 @@ public class InferenceSession extends 
AbstractSession { } } op.getOp().clearArrays(); + if(opPair.getSecond() != null) + opPair.getSecond().purge(); //Record array uses for memory management/deallocation @@ -343,7 +350,7 @@ public class InferenceSession extends AbstractSession { return out; } - public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs) { int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) @@ -467,31 +474,31 @@ public class InferenceSession extends AbstractSession { return new INDArray[]{out}; } else if (op instanceof Assert) { Assert a = (Assert)op; - boolean condition = a.getInputArgument(0).getDouble(0) != 0.0; + boolean condition = opContext.getInputArray(0).getDouble(0) != 0.0; if(!condition){ //Assertion failed String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution"; if(a.numInputArguments() >= 3) { - INDArray msg = a.getInputArgument(2); + INDArray msg = opContext.getInputArray(2); if (msg != null && msg.dataType() == DataType.UTF8) { s += ": " + msg.getString(0); } } if(a.numInputArguments() >= 5){ - INDArray arr = a.getInputArgument(4); + INDArray arr = opContext.getInputArray(4); s += "\n" + arr; } throw new IllegalStateException(s); } - return ((Assert) op).outputArguments().toArray(new INDArray[0]); + return opContext.getOutputArrays().toArray(new INDArray[0]); } else if (op instanceof CustomOp) { CustomOp c = (CustomOp) op; - Nd4j.exec(c); - return c.outputArguments().toArray(new INDArray[0]); + Nd4j.exec(c, opContext); + return opContext.getOutputArrays().toArray(new INDArray[0]); } else if (op instanceof Op) { Op o = (Op) op; - Nd4j.exec(o); - return new INDArray[]{o.z()}; + Nd4j.exec(o, opContext); + return new INDArray[]{opContext.getOutputArray(0)}; } else { throw new 
UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName()); } @@ -774,7 +781,7 @@ public class InferenceSession extends AbstractSession { } @Override - public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, + public Pair getAndParameterizeOp(String opName, FrameIter frameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, Map placeholderValues, Set allReqVariables) { SameDiffOp sdo = sameDiff.getOps().get(opName); DifferentialFunction df = sdo.getOp(); @@ -786,7 +793,7 @@ public class InferenceSession extends AbstractSession { if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) { //Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case - return sdo; + return new Pair<>(sdo, null); } //Infer the args based on the inputs (variable + frame + iteration) @@ -839,24 +846,39 @@ public class InferenceSession extends AbstractSession { //TODO let's find a way to use in-place modification for loops where possible to reduce memory requirements boolean isLoop = !frameIter.getFrame().equals(OUTER_FRAME) && frameIter.getIteration() > 0; + OpContext oc = opContexts.get(opName); + if(oc == null){ + oc = Nd4j.getExecutioner().buildContext(); + opContexts.put(opName, oc); + } + if (df instanceof CustomOp) { DynamicCustomOp customOp = (DynamicCustomOp) df; if (args != null) { - customOp.setInputArguments(args); + oc.setInputArrays(args); } if (df instanceof Identity) { //We don't need to allocate an output array for Identity, we pass through the input array without copying - return sdo; + return new Pair<>(sdo, oc); } - List outShape = customOp.calculateOutputShape(); + if(customOp.numIArguments() > 0) + oc.setIArguments(customOp.iArgs()); + if(customOp.numDArguments() > 0) + 
oc.setDArguments(customOp.dArgs()); + if(customOp.numTArguments() > 0) + oc.setTArguments(customOp.tArgs()); + if(customOp.numBArguments() > 0) + oc.setBArguments(customOp.bArgs()); + + + List outShape = customOp.calculateOutputShape(oc); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" + " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length); for (int i = 0; i < outShape.size(); i++) { - INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i)); LongShapeDescriptor reqShape = outShape.get(i); //Issue: many ops have multiple valid output datatypes, and output shape calc can't at present know which: https://github.com/deeplearning4j/deeplearning4j/issues/6872 @@ -870,7 +892,7 @@ public class InferenceSession extends AbstractSession { //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc boolean isOutput = allReqVariables.contains(outNames[i]); INDArray out = mmgr.allocate(isOutput, reqShape); - customOp.setOutputArgument(i, out); + oc.setOutputArray(i, out); } } else if (df instanceof Op) { @@ -909,9 +931,9 @@ public class InferenceSession extends AbstractSession { } if (args != null && args.length > 0) { - op.setX(args[0]); + oc.setInputArray(0, args[0]); if (args.length == 2 && !axisArg) - op.setY(args[1]); + oc.setInputArray(1, args[1]); } @@ -920,18 +942,18 @@ public class InferenceSession extends AbstractSession { boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]); if (emptyReduce) { //Always allocate new output 
array, rely on memory manager for efficient memory management and array reuse etc - INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape()); - op.setZ(z); + INDArray z = mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape()); + oc.setOutputArray(0, z); } else { - List outputShape = ((BaseOp) op).calculateOutputShape(); + List outputShape = ((BaseOp) op).calculateOutputShape(oc); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); LongShapeDescriptor lsd = outputShape.get(0); INDArray z = mmgr.allocate(isOutput, lsd); - op.setZ(z); + oc.setOutputArray(0, z); } } - return sdo; + return new Pair<>(sdo, oc); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java index 992a747a0..e683acc47 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -11,10 +11,12 @@ import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.primitives.AtomicDouble; +import org.nd4j.linalg.primitives.Pair; import java.util.*; @@ -135,10 +137,11 @@ public class TrainingSession extends InferenceSession { } @Override - public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, + public INDArray[] 
getOutputs(Pair opPair, FrameIter outputFrameIter, Set opInputs, Set allIterInputs, Set constAndPhInputs, List listeners, At at, MultiDataSet batch, Set allReqVariables) { //Get outputs from InferenceSession - INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + INDArray[] out = super.getOutputs(opPair, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables); + SameDiffOp op = opPair.getFirst(); List outputs = op.getOutputsOfOp(); int outIdx = 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java index a8865f972..d1137746d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java @@ -12,6 +12,8 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; + +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; /** @@ -36,7 +38,7 @@ public class ActivationGradientCheckListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) { Preconditions.checkState(variableName != null, "No variable name has been set yet. 
Variable name must be set before using this listener"); Preconditions.checkState(eps != 0.0, "Epsilon has not been set"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index 9eee099a5..3d28b29fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -14,7 +14,10 @@ import org.nd4j.linalg.api.ops.Op; import java.security.MessageDigest; import java.util.Arrays; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; + +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.dataset.api.MultiDataSet; public class NonInplaceValidationListener extends BaseListener { @@ -33,25 +36,25 @@ public class NonInplaceValidationListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + public void preOpExecution(SameDiff sd, At at, SameDiffOp op, OpContext oc) { if(op.getOp().isInPlace()){ //Don't check inplace op return; } if(op.getOp() instanceof Op){ Op o = (Op)op.getOp(); - if(o.x() == null){ + if(oc.getInputArray(0) == null){ //No input op return; - } else if(o.y() == null){ - opInputsOrig = new INDArray[]{o.x()}; - opInputs = new INDArray[]{o.x().dup()}; + } else if(oc.getInputArray(1) == null){ + opInputsOrig = new INDArray[]{oc.getInputArray(0)}; + opInputs = new INDArray[]{oc.getInputArray(0).dup()}; } else { - opInputsOrig = new INDArray[]{o.x(), o.y()}; - opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; + opInputsOrig = new INDArray[]{oc.getInputArray(0), oc.getInputArray(1)}; + opInputs = new INDArray[]{oc.getInputArray(0).dup(), 
oc.getInputArray(1).dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ - val arr = ((DynamicCustomOp) op.getOp()).inputArguments(); + List arr = oc.getInputArrays(); // ((DynamicCustomOp) op.getOp()).inputArguments(); opInputs = new INDArray[arr.size()]; opInputsOrig = new INDArray[arr.size()]; for( int i=0; i calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc){ + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 0139a9db5..c7d71db04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -55,6 +55,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_i; } + @Override + public int numIArguments() { + return fastpath_i.size(); + } + @Override public void setTArguments(double... arguments) { fastpath_t.clear(); @@ -67,6 +72,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_t; } + @Override + public int numTArguments() { + return fastpath_t.size(); + } + @Override public void setBArguments(boolean... arguments) { fastpath_b.clear(); @@ -79,6 +89,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_b; } + @Override + public int numBArguments() { + return fastpath_b.size(); + } + @Override public void setDArguments(DataType... 
arguments) { fastpath_d.clear(); @@ -91,6 +106,11 @@ public abstract class BaseOpContext implements OpContext { return fastpath_d; } + @Override + public int numDArguments() { + return fastpath_d.size(); + } + @Override public void setInputArray(int index, @NonNull INDArray array) { fastpath_in.put(index, array); @@ -110,6 +130,16 @@ public abstract class BaseOpContext implements OpContext { return result; } + @Override + public int numInputArguments() { + return fastpath_in.size(); + } + + @Override + public INDArray getInputArray(int idx) { + return fastpath_in.get(idx); + } + @Override public List getOutputArrays() { val result = new ArrayList(); @@ -129,6 +159,15 @@ public abstract class BaseOpContext implements OpContext { fastpath_out.put(index, array); } + @Override + public INDArray getOutputArray(int i) { + return fastpath_out.get(i); + } + + @Override + public int numOutputArguments() { + return fastpath_out.size(); + } @Override public void setInputArrays(@NonNull List arrays) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index dd2072758..af022c86f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -72,19 +72,33 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y:" + + public DataType resultType(OpContext oc) { + return DataType.BOOL; + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? 
oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument(z().isB(), "Op.X type must be bool: got type %s for op %s", x.dataType(), getClass()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.isB(), "Op.Z type must be bool: got type %s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java index 6f3722011..29860aee9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java @@ -90,27 +90,43 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl @Override public DataType resultType() { - if (this.x() != null && this.x().isR()) - return this.x().dataType(); + return resultType(null); + } + + @Override + public DataType resultType(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); + if (x != null && x.isR()) + return x.dataType(); return Nd4j.defaultFloatingPointType(); } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(), - "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x().dataType(), - y().dataType(), getClass().getName(), x(), y() ); + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(), + "Op.X [%s] type must be the same as Op.Y [%s] for op %s: x.shape=%ndShape, y.shape=%ndShape", x.dataType(), + y.dataType(), getClass().getName(), x, y ); - if (z() != null) - Preconditions.checkArgument(z().isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z().dataType()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.isR(),"Op.Z (result array) must be one of floating types: z datatype = %s", z.dataType()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); + if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java index 9f82bb6b4..b5131eb61 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java @@ -69,19 +69,33 @@ public abstract class BaseReduceLongOp extends BaseReduceOp implements ReduceLon } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y:" + + public DataType resultType(OpContext oc) { + return DataType.LONG; + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument( z().dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z().dataType(), getClass()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument( z.dataType() == DataType.LONG,"Op.Z must be long: has type %s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java index 0aa4460c3..015b87b5d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java @@ -77,26 +77,42 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam } @Override - public boolean validateDataTypes() { - if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(),"Op.X type must be the same as Op.Y type:" + + public DataType resultType(OpContext oc){ + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + if (y != null) + Preconditions.checkArgument(x.dataType() == y.dataType(),"Op.X type must be the same as Op.Y type:" + " x.dataType=%s, y.dataType=%s, op=%s", x.dataType(), y.dataType(), getClass().getName()); - if (z() != null) - Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " + - "Op.Z.datatype=%s", x().dataType(), z.dataType()); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null) + Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type. Op.X.datatype=%s, " + + "Op.Z.datatype=%s", x.dataType(), z.dataType()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); + if(x == null) return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); - return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType())); + DataType rt = oc != null ? resultType(oc) : resultType(); + return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, rt)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 082465cbe..8cb7e50b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -98,6 +98,12 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index e6df6ceec..254069929 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -115,6 +115,13 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + val ret = new ArrayList(1); long[] s; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java index 71749bdda..7aa24a60b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformAnyOp.java @@ -89,7 +89,12 @@ public abstract class BaseTransformAnyOp extends BaseTransformOp implements Tran } @Override - public boolean validateDataTypes(boolean experimentalMode) { + public DataType resultType(OpContext oc) { + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { return true; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java index fd19f23d0..d4e69db5d 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformBoolOp.java @@ -88,20 +88,34 @@ public abstract class BaseTransformBoolOp extends BaseTransformOp implements Tra } @Override - public boolean validateDataTypes(boolean experimentalMode) { + public DataType resultType(OpContext oc) { + return DataType.BOOL; + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); if (y() != null) - Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X must be the same type as Op.Y: " + - "x.datatype=%s, y.datatype=%s", x().dataType(), y.dataType()); + Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must be the same type as Op.Y: " + + "x.datatype=%s, y.datatype=%s", x.dataType(), y.dataType()); - if (z() != null) - Preconditions.checkArgument(z().isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z().dataType(), getClass()); + if (z != null) + Preconditions.checkArgument(z.isB(),"Op.Z type must be bool: z.datatype=%s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), DataType.BOOL)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java index 12516577c..42e2ef278 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformFloatOp.java @@ -72,19 +72,37 @@ public abstract class BaseTransformFloatOp extends BaseTransformOp implements Tr } @Override - public boolean validateDataTypes(boolean experimentalMode) { - if (y() != null && !experimentalMode) { + public DataType resultType(OpContext oc) { + if (oc.getInputArray(0) != null && oc.getInputArray(0).isR()) + return oc.getInputArray(0).dataType(); + + return Nd4j.defaultFloatingPointType(); + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + + if (y != null && !experimentalMode) { Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y"); } - if (z() != null) - Preconditions.checkArgument(z().isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z().dataType(), getClass()); + if (z != null) + Preconditions.checkArgument(z.isR(),"Op.Z must be one of floating types: z.datatype=%s for op %s", z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.isR() ? x.dataType() : Nd4j.defaultFloatingPointType())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java index b04c24c8c..b7d0ff4ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java @@ -89,22 +89,36 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements Tra } @Override - public boolean validateDataTypes(boolean experimentalMode) { - if (y() != null) { - Preconditions.checkArgument(x().dataType() == y().dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s", - x().dataType(), y().dataType(), getClass()); + public DataType resultType(OpContext oc) { + return oc.getInputArray(0).dataType(); + } + + @Override + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? 
oc.getOutputArray(0) : z(); + if (y != null) { + Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X type must be the same as Op.Y type: x.datatype=%s, y.datatype=%s for op %s", + x.dataType(), y.dataType(), getClass()); } - if (z() != null) - Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s", - x().dataType(), z.dataType(), getClass()); + if (z != null) + Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must be the same as Op.X type: x.datatype=%s, z.datatype=%s for op %s", + x.dataType(), z.dataType(), getClass()); return true; } @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java index d2a4dccc3..963138880 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformStrictOp.java @@ -76,20 +76,28 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T return this.x().dataType(); } + @Override + public DataType resultType(OpContext opContext) { + return opContext.getInputArray(0).dataType(); + } + @Override - public boolean validateDataTypes(boolean experimentalMode) { - Preconditions.checkArgument(x().isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x().dataType(), getClass()); + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? 
oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + Preconditions.checkArgument(x.isR(), "Op.X must be one of floating types: x.datatype=%s for op %s", x.dataType(), getClass()); - if (y() != null) { - Preconditions.checkArgument(y().isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y().dataType(), getClass()); + if (y != null) { + Preconditions.checkArgument(y.isR(), "Op.Y must be one of floating types: y.datatype=%s for op %s", y.dataType(), getClass()); if (!experimentalMode) Preconditions.checkArgument(x.dataType() == y.dataType(), "Op.X must have same data type as Op.Y"); } if (z() != null) - Preconditions.checkArgument(z().dataType() == x().dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s", + Preconditions.checkArgument(z.dataType() == x.dataType(), "Op.Z must have the same type as Op.X: x.datatype=%s, z.datatype=%s for op %s", x.dataType(), z.dataType(), getClass()); return true; @@ -102,6 +110,13 @@ public abstract class BaseTransformStrictOp extends BaseTransformOp implements T return Collections.singletonList(LongShapeDescriptor.fromShape(x.shape(), x.dataType())); } + @Override + public List calculateOutputShape(OpContext oc) { + if(oc.getInputArray(0) == null) + return Collections.emptyList(); + return Collections.singletonList(LongShapeDescriptor.fromShape(oc.getInputArray(0).shape(), oc.getInputArray(0).dataType())); + } + @Override public List calculateOutputDataTypes(List dataTypes){ //All strict tranform ops: FP in, FP out diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java index befdfb605..cdf8e3b36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java @@ -108,10 +108,16 @@ public interface CustomOp { /** * Calculate the output shape for this op - * @return + * @return Output array shapes */ List calculateOutputShape(); + /** + * Calculate the output shape for this op + * @return Output array shapes + */ + List calculateOutputShape(OpContext opContext); + /** * Get the custom op descriptor if one is available. * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index f4116ba3e..3fe90bdbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -493,6 +493,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { val descriptor = getDescriptor(); if (outputShapes != null && !outputShapes.isEmpty()) return outputShapes; @@ -504,34 +509,41 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { //not fully initialized: missing integer args - if (descriptor.getNumIArgs() >= 0 && numIArguments() < descriptor.getNumIArgs()) { + int nI = oc != null ? 
oc.numIArguments() : numIArguments(); + if (descriptor.getNumIArgs() >= 0 && nI < descriptor.getNumIArgs()) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} IArgs specified, " + - "{} required)", getClass().getName(),numIArguments(), descriptor.getNumIArgs()); + "{} required)", getClass().getName(), nI, descriptor.getNumIArgs()); } return Collections.emptyList(); } //not fully initialized: missing floating point args - if (descriptor.getNumTArgs() >= 0 && numTArguments() < descriptor.getNumTArgs()) { + int nT = oc != null ? oc.numTArguments() : numTArguments(); + if (descriptor.getNumTArgs() >= 0 && nT < descriptor.getNumTArgs()) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} TArgs specified, " + - "{} required)", getClass().getName(),numTArguments(), descriptor.getNumTArgs()); + "{} required)", getClass().getName(), nT, descriptor.getNumTArgs()); } return Collections.emptyList(); } //not fully initialized: missing INDArray input args - if(descriptor.getNumInputs() >= 0 && numInputArguments() < descriptor.getNumInputs()){ + int nIn = oc != null ? 
oc.numInputArguments() : numInputArguments(); + if(descriptor.getNumInputs() >= 0 && nIn < descriptor.getNumInputs()){ if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: not fully initialized ({} input (INDArray) args specified, " + - "{} required)", getClass().getName(),numInputArguments(), descriptor.getNumInputs()); + "{} required)", getClass().getName(), nIn, descriptor.getNumInputs()); } return Collections.emptyList(); } - List ret = Nd4j.getExecutioner().calculateOutputShape(this); + List ret; + if(oc == null) + ret = Nd4j.getExecutioner().calculateOutputShape(this); + else + ret = Nd4j.getExecutioner().calculateOutputShape(this, oc); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index 554ad917e..b4cf2d05a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -89,6 +89,14 @@ public class NoOp extends DynamicCustomOp { return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor()); } + @Override + public List calculateOutputShape(OpContext oc){ + if(oc.getInputArrays() != null && !oc.getInputArrays().isEmpty()){ + return Collections.singletonList(oc.getInputArray(0).shapeDescriptor()); + } + return Collections.singletonList(Nd4j.empty(DataType.BOOL).shapeDescriptor()); + } + @Override public List calculateOutputDataTypes(List inputDataTypes){ return Collections.singletonList(DataType.BOOL); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index 62a4906a7..4bda3701e 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -39,12 +39,15 @@ public interface OpContext extends AutoCloseable { List getIArguments(); + int numIArguments(); + /** * This method sets floating point arguments required for operation * @param arguments */ void setTArguments(double... arguments); List getTArguments(); + int numTArguments(); /** * This method sets data type arguments required for operation @@ -52,14 +55,15 @@ public interface OpContext extends AutoCloseable { */ void setDArguments(DataType... arguments); List getDArguments(); + int numDArguments(); /** * This method sets boolean arguments required for operation * @param arguments */ void setBArguments(boolean... arguments); - List getBArguments(); + int numBArguments(); /** * This method sets root-level seed for rng @@ -99,6 +103,10 @@ public interface OpContext extends AutoCloseable { */ List getInputArrays(); + int numInputArguments(); + + INDArray getInputArray(int idx); + /** * This method adds INDArray as output for future op call * @param index @@ -124,6 +132,10 @@ public interface OpContext extends AutoCloseable { */ List getOutputArrays(); + INDArray getOutputArray(int i); + + int numOutputArguments(); + /** * This method returns pointer to context, to be used during native op execution * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java index 8f1814dfd..23d81c5b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java @@ -86,7 +86,9 @@ public interface ReduceOp extends Op { */ DataType resultType(); - boolean 
validateDataTypes(); + DataType resultType(OpContext oc); + + boolean validateDataTypes(OpContext oc); Number getFinalResult(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java index 9c3f9b423..f50116d32 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/TransformOp.java @@ -31,7 +31,9 @@ public interface TransformOp extends Op { */ DataType resultType(); + DataType resultType(OpContext opContext); + Type getOpType(); - boolean validateDataTypes(boolean experimentalMode); + boolean validateDataTypes(OpContext opContext, boolean experimentalMode); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 313b7ccb4..50c5db75e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOpDescriptor; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -237,6 +238,11 @@ public class ScatterUpdate implements CustomOp { return Nd4j.getExecutioner().calculateOutputShape(this); } + @Override + public List calculateOutputShape(OpContext opContext) { + return 
Nd4j.getExecutioner().calculateOutputShape(this, opContext); + } + @Override public CustomOpDescriptor getDescriptor() { return op.getDescriptor(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index aea251ebd..c60b11d23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -55,7 +55,7 @@ import java.util.*; * @author Adam Gibson */ @Slf4j -public class DefaultOpExecutioner implements OpExecutioner { +public abstract class DefaultOpExecutioner implements OpExecutioner { private static final String SCOPE_PANIC_MSG = "For more details, see the ND4J User Guide: deeplearning4j.org/docs/latest/nd4j-overview#workspaces-panic"; @@ -108,9 +108,10 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(Op op) { - throw new IllegalStateException("Java computation no longer supported"); - } + public abstract INDArray exec(Op op); + + @Override + public abstract INDArray exec(Op op, OpContext opContext); @Override public Op execAndReturn(Op op) { @@ -175,24 +176,16 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(ReduceOp op) { - throw new UnsupportedOperationException("Java computation no longer supported"); - } + public abstract INDArray exec(ReduceOp op); @Override - public INDArray exec(Variance accumulation) { - throw new UnsupportedOperationException("Operation should use exec special"); - } + public abstract INDArray exec(Variance accumulation); @Override - public INDArray exec(IndexAccumulation op) { - throw new UnsupportedOperationException("Operation should use exec 
special"); - } + public abstract INDArray exec(IndexAccumulation op); @Override - public INDArray exec(BroadcastOp broadcast) { - throw new IllegalStateException("Java computation no longer supported"); - } + public abstract INDArray exec(BroadcastOp broadcast); @Override public void exec(MetaOp op) { @@ -215,9 +208,7 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public INDArray exec(ScalarOp op) { - throw new UnsupportedOperationException(); - } + public abstract INDArray exec(ScalarOp op); @Override public void exec(List batch) { @@ -241,9 +232,7 @@ public class DefaultOpExecutioner implements OpExecutioner { * @param rng */ @Override - public INDArray exec(RandomOp op, Random rng) { - throw new UnsupportedOperationException(); - } + public abstract INDArray exec(RandomOp op, Random rng); @Deprecated @@ -741,6 +730,11 @@ public class DefaultOpExecutioner implements OpExecutioner { throw new UnsupportedOperationException(); } + @Override + public List calculateOutputShape(CustomOp op, OpContext opContext) { + throw new UnsupportedOperationException(); + } + @Override public INDArray[] allocateOutputArrays(CustomOp op){ List shapes = calculateOutputShape(op); @@ -946,4 +940,44 @@ public class DefaultOpExecutioner implements OpExecutioner { public String runFullBenchmarkSuit(boolean printOut) { throw new UnsupportedOperationException(); } + + + public void setX(INDArray x, Op op, OpContext oc){ + if(oc != null) + oc.setInputArray(0, x); + else + op.setX(x); + } + + public INDArray getX(Op op, OpContext oc){ + if( oc != null ) + return oc.getInputArray(0); + return op.x(); + } + + public void setY(INDArray y, Op op, OpContext oc){ + if(oc != null) + oc.setInputArray(1, y); + else + op.setY(y); + } + + public INDArray getY(Op op, OpContext oc){ + if( oc != null ) + return oc.getInputArray(1); + return op.y(); + } + + public void setZ(INDArray z, Op op, OpContext oc){ + if(oc != null) + oc.setOutputArray(0, z); + else + op.setZ(z); + } + 
+ public INDArray getZ(Op op, OpContext oc){ + if( oc != null ) + return oc.getOutputArray(0); + return op.z(); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java index c4af57864..fcf8bfd3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java @@ -98,6 +98,13 @@ public interface OpExecutioner { */ INDArray exec(Op op); + /** + * Execute the operation + * + * @param op the operation to execute + */ + INDArray exec(Op op, OpContext opContext); + /**Execute a TransformOp and return the result * @param op the operation to execute */ @@ -364,6 +371,8 @@ public interface OpExecutioner { List calculateOutputShape(CustomOp op); + List calculateOutputShape(CustomOp op, OpContext opContext); + /** * Equivalent to calli */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java index 496943b45..8b8cc5f53 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import 
org.tensorflow.framework.AttrValue; @@ -150,6 +151,11 @@ public class ExternalErrorsFunction extends DynamicCustomOp { return OUT_SHAPE; } + @Override + public List calculateOutputShape(OpContext oc){ + return OUT_SHAPE; + } + public Op.Type opType() { return Op.Type.LOGIC; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index 504012703..b44b11cf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -131,8 +132,14 @@ public class Variance extends BaseReduceOp { @Override public DataType resultType() { - if (this.x() != null && this.x().isR()) - return this.x().dataType(); + return resultType(null); + } + + @Override + public DataType resultType(OpContext oc){ + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if (x != null && x.isR()) + return x.dataType(); if(this.arg() != null){ return this.arg().dataType(); @@ -142,14 +149,18 @@ public class Variance extends BaseReduceOp { } @Override - public boolean validateDataTypes() { - if (!x().isR()) + public boolean validateDataTypes(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + if (x != null && !x.isR()) { + return false; + } + + INDArray y = oc != null ? 
oc.getInputArray(1) : y(); + if (y != null && !y.isR()) return false; - if (y() != null && !y().isR()) - return false; - - if (z() != null && !z().isR()) + INDArray z = oc != null ? oc.getOutputArray(0) : z(); + if (z != null && !z.isR()) return false; return true; @@ -157,15 +168,22 @@ public class Variance extends BaseReduceOp { @Override public List calculateOutputShape() { - if(args().length < 1) { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext oc) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + + if(oc == null && args().length < 1) { throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found."); } long[] argShape = arg().getShape(); - if (argShape == null && x() == null) { + if (argShape == null && x == null) { return Collections.emptyList(); } - long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? x().shape() : argShape); + long[] inputShape = (argShape == null || Shape.isPlaceholderShape(argShape) ? 
x.shape() : argShape); val ret = new ArrayList(1); val reducedShape = Shape.getReducedShape(inputShape,dimensions, isKeepDims()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index 9c8607b98..765ab3341 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -94,20 +95,29 @@ public class MaxOut extends BaseTransformOp { return Nd4j.defaultFloatingPointType(); } + @Override + public DataType resultType(OpContext oc) { + return Nd4j.defaultFloatingPointType(); + } + @Override public Type getOpType() { return Type.TRANSFORM_STRICT; } @Override - public boolean validateDataTypes(boolean experimentalMode) { - if (!x().isR()) + public boolean validateDataTypes(OpContext oc, boolean experimentalMode) { + INDArray x = oc != null ? oc.getInputArray(0) : x(); + INDArray y = oc != null ? oc.getInputArray(1) : y(); + INDArray z = oc != null ? 
oc.getOutputArray(0) : z(); + + if (!x.isR()) return false; - if (y() != null && !y().isR()) + if (y != null && !y().isR()) return false; - if (z() != null && z().dataType() != x().dataType()) + if (z != null && z().dataType() != x().dataType()) return false; return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index 51f682876..752881c6e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.RandomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -65,6 +66,11 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { @Override public List calculateOutputShape() { + return calculateOutputShape(null); + } + + @Override + public List calculateOutputShape(OpContext opContext) { if(shape != null){ return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); } else { @@ -83,4 +89,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { public boolean isInPlace(){ return x == null || x == z || x.data().pointer().address() == z.data().pointer().address(); } + + public boolean isTripleArgRngOp(){ + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index 35c6ee05e..b08f56be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -139,4 +139,9 @@ public class BinomialDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index ed43f807d..1081e141b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -138,4 +138,9 @@ public class GaussianDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index 080a7305a..c007d4e92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -135,4 +135,9 @@ public class LogNormalDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index 24e52a532..ba09a2d29 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -136,4 +136,9 @@ public class TruncatedNormalDistribution extends BaseRandomOp { //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 return Collections.singletonList(DataType.DOUBLE); } + + @Override + public boolean isTripleArgRngOp() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index bafee4003..5da64dadb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -6556,6 +6556,10 @@ public class Nd4j { return getExecutioner().exec(op); } + public static INDArray exec(Op op, OpContext context){ + return getExecutioner().exec(op, context); + } + /** * Execute the operation and return the result * diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java index fa92f94f5..d34e24def 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java @@ -54,7 +54,7 @@ public abstract class Nd4jBlas implements Blas { } String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION); - if(logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit)) { + if(logOpenMPBlasThreads() && (logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit))) { log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); } } @@ -74,4 +74,8 @@ public abstract class Nd4jBlas implements Blas { } return Vendor.values()[vendor]; } + + public boolean logOpenMPBlasThreads(){ + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java index ce5ac2a0d..624460b50 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/CudaBlas.java @@ -134,4 +134,9 @@ public class CudaBlas extends Nd4jBlas { public int getBlasVendorId() { return 1; } + + @Override + public boolean logOpenMPBlasThreads() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index f18bd1459..a6ccd25ed 100644 --- 
a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -44,6 +44,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.random.BaseRandomOp; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -229,7 +230,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { INDArray ret = op.z(); checkForCompression(op); - op.validateDataTypes(); + op.validateDataTypes(null); //validateDataType(Nd4j.dataType(), op); for (int i = 0; i < dimension.length; i++) @@ -614,8 +615,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(Op op) { + return exec(op, null); + } + + @Override + public INDArray exec(Op op, OpContext oc) { checkForCompression(op); + /* + //TODO this never would have worked //linear views and oblong offsets can't be handled by the gpu (due to the way the buffers are interpreted as vectors) if ( op instanceof CopyOp) { // we dont' care about op.Z sync state, since it'll be overwritten @@ -631,27 +639,27 @@ public class CudaExecutioner extends DefaultOpExecutioner { //AtomicAllocator.getInstance().tickHostWrite(op.z()); return null; - } + }*/ if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; - invoke(t); + invoke(t, oc); } else if (op instanceof ReduceOp) { ReduceOp acc = (ReduceOp) op; - invoke(acc, acc.dimensions().toIntVector()); + invoke(acc, oc, acc.dimensions().toIntVector()); } else if (op instanceof ScalarOp) { ScalarOp sc = (ScalarOp) op; - invoke(sc); + invoke(sc, oc); } 
else if (op instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) op; - invoke(broadcastOp); + invoke(broadcastOp, oc); } else if (op instanceof IndexAccumulation) { IndexAccumulation indexAccumulation = (IndexAccumulation) op; - invoke(indexAccumulation, indexAccumulation.dimensions().toIntVector()); + invoke(indexAccumulation, oc, indexAccumulation.dimensions().toIntVector()); } else if (op instanceof RandomOp) { - exec((RandomOp) op); + exec((RandomOp) op, oc, Nd4j.getRandom()); } else if (op instanceof CustomOp) { - exec((CustomOp) op); + exec((CustomOp) op, oc); } @@ -659,19 +667,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { } - @Override public TransformOp execAndReturn(TransformOp op) { checkForCompression(op); - invoke(op); + invoke(op, null); return op; } - protected CudaContext invoke(BroadcastOp op) { + protected CudaContext invoke(BroadcastOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + checkForCompression(op); //validateDataType(Nd4j.dataType(), op); @@ -684,17 +695,17 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); val hostXShapeInfo = - op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); + x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); val hostYShapeInfo = - op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); + y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); val hostZShapeInfo = - op.z() == null ? 
null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - val tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension()); + val tadBuffers = tadManager.getTADOnlyShapeInfo(x, op.getDimension()); val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); @@ -706,13 +717,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer devTadOffsetsZ = null; // that's the place where we're going to have second TAD in place - val tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension()); + val tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, op.getDimension()); devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context); devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0 + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0 context.getOldStream(), // 1 AtomicAllocator.getInstance().getDeviceIdPointer(), // 2 context.getBufferAllocation(), // 3 @@ -727,30 +738,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { devTadShapeInfoZ, // 12 devTadOffsetsZ); // 13 - Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); - Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); Pointer dimensionPointer = 
AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); //log.info("X: {}; Y: {}; Z: {}; dTS: {}, dTO: {}; dTSz: {}; dTOz: {};", x.address(), y.address(), z.address(), devTadShapeInfo.address(), devTadOffsets.address(), devTadShapeInfoZ.address(), devTadOffsetsZ.address()); switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, null, ((BaseCudaDataBuffer) 
op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -768,11 +779,16 @@ public class CudaExecutioner extends DefaultOpExecutioner { - protected CudaContext invoke(IndexAccumulation op, int[] dimension) { - dimension = Shape.normalizeAxis(op.x().rank(), dimension); + protected CudaContext invoke(IndexAccumulation op, OpContext oc, int[] dimension) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + + dimension = Shape.normalizeAxis(x.rank(), dimension); if (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) { - if(op.z() == op.x() || op.z() == null) { - op.setZ(Nd4j.createUninitialized(DataType.LONG, new long[0], 'c')); + if(z == x || z == null) { + z = Nd4j.createUninitialized(DataType.LONG, new long[0], 'c'); + setZ(z, op, oc); } } @@ -790,46 +806,45 @@ public class CudaExecutioner extends DefaultOpExecutioner { CudaEnvironment.getInstance().getConfiguration().enableDebug(true); if (dimension != null) for (int i = 0; i < dimension.length; i++) - if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) - throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); + if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE) + throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + x.rank() + "]"); val context = AtomicAllocator.getInstance().getDeviceContext(); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null ? 
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(x.dataType()), context) : null; - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); int fdimension[] = dimension; if (fdimension == null) fdimension = new int[] {0}; - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, fdimension); Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? 
null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets); - if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { + if (z.isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { if (dimension != null && dimension.length > 1) Arrays.sort(dimension); @@ -839,9 +854,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension)); nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } @@ -855,30 +870,34 @@ public class CudaExecutioner extends 
DefaultOpExecutioner { } - protected CudaContext invoke(ReduceOp op, int[] dimension) { + protected CudaContext invoke(ReduceOp op, OpContext oc, int[] dimension) { val context = AtomicAllocator.getInstance().getDeviceContext(); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" - if(op.z() != null){ - Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + - " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); - op.z().assign(op.x()); + if(z != null){ + Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", x, z); + z.assign(x); return context; } else { - op.setZ(op.x().dup()); + op.setZ(x.dup()); return context; } } // FIXME: this should be moved down to C++ on per-op basis // reduce to scalar case, ReduceBool ops require special treatment - if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { - if (op.z() == null) { + if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (z == null) { op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); } else { - op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + z.assign(((BaseReduceBoolOp) op).emptyValue()); } return context; @@ -888,7 +907,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { checkForCompression(op); - dimension = Shape.normalizeAxis(op.x().rank(), dimension); + dimension = Shape.normalizeAxis(x.rank(), dimension); //validateDataType(Nd4j.dataType(), op); 
@@ -903,130 +922,131 @@ public class CudaExecutioner extends DefaultOpExecutioner { Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) - if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) + if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE) throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) - + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); + + " contains element that higher then rank of op.X: [" + x.rank() + "]"); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - val tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null) : tadManager.getTADOnlyShapeInfo(op.x(), dimension); + val tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), null) : tadManager.getTADOnlyShapeInfo(x, dimension); val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); - val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); + val offsets = x.isEmpty() ? null : tadBuffers.getSecond(); val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); - long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); + long[] retShape = Shape.reductionShape(x, dimension, true, op.isKeepDims()); - if (op.y() != null) { + if (y != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. 
entirety of Y - if (op.x().length() == op.y().length()) { + if (x.length() == y.length()) { //Pairwise - if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { + if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + - Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")"); } } else { //Every X TAD vs. entirety of Y - val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); + val xTADSize = x.length() / x.tensorsAlongDimension(dimension); - if (xTADSize != op.y().length()) { + if (xTADSize != y.length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); + " (x TAD size = " + xTADSize + ", y size = " + y.length()); } } } - //if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { + //if (x.isVector() && x.length() == ArrayUtil.prod(retShape)) { // return null; //} - val dataType = op.resultType(); + val dataType = oc != null ? 
op.resultType(oc) : op.resultType(); - if( op.z() == null ){ + if( z == null ){ val ret = Nd4j.createUninitialized(dataType, retShape); - op.setZ(ret); - } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){ + setZ(ret, op, oc); + z = ret; + } else if(z.dataType() != dataType || !Arrays.equals(retShape, z.shape())){ throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) - + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape())); + + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape())); } - val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType()); + val eb = op.extraArgsDataBuff(z.dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? x.dataType() : z.dataType()); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null; - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? 
null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); val xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets); - val yTadBuffers = op.y() == null ? null : tadManager.getTADOnlyShapeInfo(op.y(), dimension); + val yTadBuffers = y == null ? null : tadManager.getTADOnlyShapeInfo(y, dimension); - val yDevTadShapeInfo = op.y() == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context); - val yOffsets = op.y() == null ? null : yTadBuffers.getSecond(); + val yDevTadShapeInfo = y == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context); + val yOffsets = y == null ? null : yTadBuffers.getSecond(); val yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context); - if (op.y() != null) { + if (y != null) { xShapeInfoHostPointer.put(12, yDevTadShapeInfo); xShapeInfoHostPointer.put(13, yDevTadOffsets); } - val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? 
null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - op.validateDataTypes(); + op.validateDataTypes(null); - if (op.z().isScalar()) { + if (z.isScalar()) { if (op instanceof Variance) { nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((Variance) op).isBiasCorrected()); - } else if (op.y() != null) { - Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + } else if (y != null) { + Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) 
zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; default: throw new UnsupportedOperationException(); @@ -1035,21 +1055,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { } else { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); - if (op.y() != null) { - val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); + if (y != null) { + val yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context); nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, 
(LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); } else { if (op instanceof Variance) { nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected(), (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); @@ -1057,30 +1077,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - 
z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: @@ -1187,34 +1207,40 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(ScalarOp op) { - invoke(op); + invoke(op, null); return op.z(); } - protected CudaContext invoke(ScalarOp op) { + protected CudaContext invoke(ScalarOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); checkForCompression(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + // validateDataType(Nd4j.dataType(), op); - if(op.z() == null){ + if(z == null){ switch (op.getOpType()) { case SCALAR: - op.setZ(op.x().ulike()); + z = x.ulike(); + setZ(x.ulike(), op, oc); break; case SCALAR_BOOL: - op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + z = Nd4j.createUninitialized(DataType.BOOL, x.shape()); + setZ(z, op, oc); break; default: throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); } } - if (op.x().length() != op.z().length()) + if (x.length() != z.length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: [" - + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" - + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]"); + + 
Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != [" + + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]"); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -1229,38 +1255,38 @@ public class CudaExecutioner extends DefaultOpExecutioner { val context = AtomicAllocator.getInstance().getDeviceContext(); - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? 
x.dataType() : z.dataType()), context) : null; - Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( - AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), + AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR_BOOL: nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; case SCALAR: nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; default: @@ -1275,9 +1301,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { return null; } - protected CudaContext invoke(TransformOp op) { + protected CudaContext invoke(TransformOp op, OpContext oc) { long st = profilingConfigurableHookIn(op); + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + checkForCompression(op); //validateDataType(Nd4j.dataType(), op); @@ -1295,7 +1325,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { // special temp array for IsMax along dimension INDArray ret = null; - Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context); + Pointer xShapeInfo = 
allocator.getPointer(x.shapeInfoDataBuffer(), context); Pointer dimensionDevPointer = null; @@ -1304,17 +1334,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer retHostShape = null; int dimension[] = null; - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + var hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); - if (op.z() == null) { - ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering()); - op.setZ(ret); + if (z == null) { + ret = Nd4j.createUninitialized(op.resultType(), x.shape(), x.ordering()); + setZ(ret, op, oc); + z = ret; } - var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? x.dataType() : z.dataType()), context) : null; + val hostZShapeInfo = z == null ? 
null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); Pointer hostTadShapeInfo = null; Pointer devTadShapeInfo = null; @@ -1328,13 +1359,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer devTadOffsets = null; Pointer devMaxTadOffsets = null; - op.validateDataTypes(experimentalMode.get()); + op.validateDataTypes(oc, experimentalMode.get()); - Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); + Pointer zShapeInfo = allocator.getPointer(z.shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = - extraz.get().put(AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), // 0 + extraz.get().put(AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0 context.getOldStream(), // 1 allocator.getDeviceIdPointer(), // 2 context.getBufferAllocation(), // 3 @@ -1356,30 +1387,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { retHostShape); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.y() != null) { - Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context); + if (y != null) { + Pointer yShapeInfo = allocator.getPointer(y.shapeInfoDataBuffer(), context); - if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) + if (x.length() != y.length() || x.length() != z.length()) throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform"); switch (op.getOpType()) { case TRANSFORM_BOOL: case PAIRWISE_BOOL: nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; } @@ -1387,32 +1418,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case TRANSFORM_ANY: nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_FLOAT: 
nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_BOOL: nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_SAME: nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_STRICT: nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), - x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, - z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: @@ -1478,6 +1509,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(RandomOp op, Random rng) { + return exec(op, null, rng); + } + + public INDArray exec(RandomOp op, OpContext oc, Random rng){ + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + + if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){ + //Ugly hack to ensure the triple arg call occurs + //See GaussianDistribution.setZ etc + x = z; + y = z; + } + long st = profilingConfigurableHookIn(op); 
checkForCompression(op); @@ -1496,38 +1542,38 @@ public class CudaExecutioner extends DefaultOpExecutioner { val context = AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()), + PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); - val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); - val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); + val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); + val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.x() != null && op.y() != null && op.z() != null) { + if (x != null && y != null && z != null) { // triple arg call nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); + xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context), + yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context), + zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context)); - } else if (op.x() != null && op.z() != null) { + } else if (x != null && z != null) { //double arg call nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()),context)); + xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context), + zb, (LongPointer) hostZShapeInfo, (LongPointer) 
AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()),context)); } else { // single arg call nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); + zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context), + AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context)); } if (nativeOps.lastErrorCode() != 0) @@ -1535,7 +1581,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { profilingConfigurableHookOut(op, st); - return op.z(); + return z; } /** @@ -1888,6 +1934,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public List calculateOutputShape(@NonNull CustomOp op) { + return calculateOutputShape(op, null); + } + + @Override + public List calculateOutputShape(@NonNull CustomOp op, OpContext opContext){ Nd4j.getExecutioner().commit(); @@ -1895,7 +1946,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hash = op.opHash(); val result = new ArrayList(); - if(op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) { + int nIn = opContext != null ? 
opContext.numInputArguments() : op.numInputArguments(); + if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); @@ -1903,47 +1955,75 @@ public class CudaExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - val inputBuffers = new PointerPointer<>(op.inputArguments().size() * 2); - val inputShapes = new PointerPointer<>(op.inputArguments().size()); + val inputBuffers = new PointerPointer<>(nIn * 2); + val inputShapes = new PointerPointer<>(nIn); + val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments(); int cnt= 0; - for (val in: op.inputArguments()) { + for (val in: inputArgs) { // NOT A TYPO: shape functions work on host side only if (!in.isEmpty()) { inputBuffers.put(cnt, in.data().addressPointer()); - inputBuffers.put(cnt + op.inputArguments().size(), AtomicAllocator.getInstance().getPointer(in.data())); + inputBuffers.put(cnt + nIn, AtomicAllocator.getInstance().getPointer(in.data())); } inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); } - val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null; + int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments(); + val iArgs = nIArgs > 0 ? new LongPointer(nIArgs) : null; cnt = 0; - for (val i: op.iArgs()) - iArgs.put(cnt++, i); + if(opContext != null){ + for (val i: opContext.getIArguments()) + iArgs.put(cnt++, i); + } else { + for (val i: op.iArgs()) + iArgs.put(cnt++, i); + } - val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null; + int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments(); + val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null; - val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null; + int nBArgs = opContext != null ? 
opContext.numBArguments() : op.numBArguments(); + val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null; - val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments(); + val dArgs = nDArgs > 0 ? new IntPointer(nDArgs) : null; cnt = 0; - for (val b: op.bArgs()) - bArgs.put(cnt++, b); + if(opContext != null){ + for (val b: opContext.getBArguments()) + bArgs.put(cnt++, b); + } else { + for (val b: op.bArgs()) + bArgs.put(cnt++, b); + } + cnt = 0; - for (val t: op.tArgs()) - tArgs.put(cnt++, t); + if(opContext != null){ + for (val b: opContext.getTArguments()) + tArgs.put(cnt++, b); + } else { + for (val b: op.tArgs()) + tArgs.put(cnt++, b); + } cnt = 0; - val dArgs1 = op.dArgs(); - for (val d: dArgs1) - dArgs.put(cnt++, d.toInt()); + if(opContext != null){ + for (val b: opContext.getDArguments()) + dArgs.put(cnt++, b.toInt()); + } else { + for (val b: op.dArgs()) + dArgs.put(cnt++, b.toInt()); + } - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, + hash, inputBuffers, inputShapes, nIn, tArgs, nTArgs, + iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs); +// OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index 850096359..1d8a3de65 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -20,6 +20,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import lombok.val; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.primitives.Pair; @@ -127,7 +128,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // the only entry place for TADless ops processAsGridOp(op); } else if (op instanceof BroadcastOp) { - invoke((BroadcastOp) op); + invoke((BroadcastOp) op, null); } else { //logger.info("Random op: {}", op.getClass().getSimpleName()); pushToGrid(new OpDescriptor(op)); @@ -238,7 +239,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio flushQueue(); //logger.info("Sending TransformOp to CudaExecutioner"); - super.invoke(t); + super.invoke(t, null); } else if (op instanceof Variance) { Variance acc = (Variance) op; if (flush) @@ -258,7 +259,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio flushQueue(); //logger.info("Sending ScalarOp to CudaExecutioner"); - super.invoke(sc); + super.invoke(sc, null); } else if (op instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) op; if (flush) @@ -268,7 +269,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio if (dimensions != null) { super.exec(broadcastOp); } else { - super.invoke(broadcastOp); + super.invoke(broadcastOp, null); } } else if (op instanceof IndexAccumulation) { IndexAccumulation 
indexAccumulation = (IndexAccumulation) op; @@ -690,7 +691,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio flushQueue(); buildZ(op, new int[] {Integer.MAX_VALUE}); - super.invoke(op, new int[] {Integer.MAX_VALUE}); + super.invoke(op, null, new int[] {Integer.MAX_VALUE}); } else { buildZ(op, dimension); processAsGridOp(op, dimension); @@ -708,7 +709,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // FIXME: remove CudaContext return opType. We just don't need it @Override - protected CudaContext invoke(BroadcastOp op) { + protected CudaContext invoke(BroadcastOp op, OpContext oc) { + Preconditions.checkState(oc == null); processAsGridOp(op, op.getDimension()); return null; @@ -716,7 +718,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // FIXME: remove CudaContext return opType. We just don't need it @Override - protected CudaContext invoke(ScalarOp op) { + protected CudaContext invoke(ScalarOp op, OpContext oc) { + Preconditions.checkState(oc == null); processAsGridOp(op, null); return null; @@ -724,7 +727,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // FIXME: remove CudaContext return opType. 
We just don't need it @Override - protected CudaContext invoke(TransformOp op) { + protected CudaContext invoke(TransformOp op, OpContext oc) { + Preconditions.checkState( oc == null); processAsGridOp(op, null); return null; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index cc3d17b5f..7a29f71d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -44,6 +44,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.random.BaseRandomOp; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -135,26 +136,31 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(Op op) { + return exec(op, null); + } + + @Override + public INDArray exec(Op op, OpContext opContext) { checkForCompression(op); if (op instanceof ScalarOp) { ScalarOp s = (ScalarOp) op; - exec(s); + exec(s, opContext); } else if (op instanceof TransformOp) { TransformOp t = (TransformOp) op; - exec(t); + exec(t, opContext); } else if (op instanceof ReduceOp) { ReduceOp ac = (ReduceOp) op; - exec(ac); + exec(ac, opContext); } else if (op instanceof IndexAccumulation) { IndexAccumulation iac = (IndexAccumulation) op; - exec(iac); //Currently using DefaultOpExecutioner + exec(iac, opContext); //Currently using DefaultOpExecutioner } else if (op 
instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) op; - exec(broadcastOp); + exec(broadcastOp, opContext); } else if (op instanceof RandomOp) { RandomOp rngOp = (RandomOp) op; - exec(rngOp, Nd4j.getRandom()); + exec(rngOp, opContext, Nd4j.getRandom()); } return op.z(); @@ -163,36 +169,44 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(IndexAccumulation op) { + return exec(op, null); + } + + public INDArray exec(IndexAccumulation op, OpContext oc) { checkForCompression(op); + INDArray x = getX(op, oc); + INDArray z = getZ(op, oc); + if (extraz.get() == null) extraz.set(new PointerPointer(32)); - val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector()); - if (op.x().isEmpty()) { + if (x.isEmpty()) { for (val d:dimension) { - Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); + Preconditions.checkArgument(x.shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); } } boolean keepDims = op.isKeepDims(); - long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims); + long[] retShape = Shape.reductionShape(x, dimension, true, keepDims); - if(op.z() == null || op.x() == op.z()) { + if(z == null || x == z) { val ret = Nd4j.createUninitialized(DataType.LONG, retShape); - op.setZ(ret); - } else if(!Arrays.equals(retShape, op.z().shape())){ + setZ(ret, op, oc); + z = ret; + } else if(!Arrays.equals(retShape, z.shape())){ throw new IllegalStateException("Z array shape does not match expected return type for op " + op - + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(op.z().shape())); + + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape())); } op.validateDataTypes(); Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, 
DataType.INT).addressPointer(); - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension); Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); @@ -203,19 +217,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op, tadBuffers.getFirst()); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.z().isScalar()) { + if (z.isScalar()) { loop.execIndexReduceScalar(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null); } else { loop.execIndexReduce(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } @@ -223,7 +237,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new RuntimeException(loop.lastErrorMessage()); profilingConfigurableHookOut(op, st); - return op.z(); + return getZ(op, oc); } @Override @@ -233,34 +247,41 @@ public class NativeOpExecutioner extends 
DefaultOpExecutioner { @Override public INDArray exec(ReduceOp op) { - Preconditions.checkNotNull(op.x(), "Op.x() cannot be null: Was null for op %s", op); - op.validateDataTypes(); + return exec(op, null); + } + + public INDArray exec(ReduceOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + Preconditions.checkNotNull(x, "Op.x() cannot be null: Was null for op %s", op); + op.validateDataTypes(oc); if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" - if(op.z() != null){ - Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + - " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); - op.z().assign(op.x()); - return op.z(); + if(z != null){ + Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." 
+ + " Got: x=%ndShape, z=%ndShape", x, z); + z.assign(x); + return z; } else { - op.setZ(op.x().dup()); - return op.z(); + setZ(x.dup(), op, oc); + return z; } } // FIXME: this should be moved down to C++ on per-op basis - val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector()); // reduce to scalar case, ReduceBool ops require special treatment - if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { - if (op.z() == null) { - op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); + if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (z == null) { + setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()), op, oc); } else { - op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + z.assign(((BaseReduceBoolOp) op).emptyValue()); } - return op.z(); + return z; } //validateDataType(Nd4j.dataType(), op); @@ -269,10 +290,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { extraz.set(new PointerPointer(32)); boolean keepDims = op.isKeepDims(); - long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims); + long[] retShape = Shape.reductionShape(x, dimension, true, keepDims); - if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null) + if (x.isVector() && x.length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && y == null) return op.noOp(); /** @@ -280,92 +301,94 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * We create it only if we hadn't provided it before */ INDArray ret; - if (op.z() == null || op.z() == op.x()) { + if (z == null || z == x) { if (op.isComplexAccumulation()) { - long xT = op.x().tensorsAlongDimension(dimension); - long yT = 
op.y().tensorsAlongDimension(dimension); + long xT = x.tensorsAlongDimension(dimension); + long yT = y.tensorsAlongDimension(dimension); ret = Nd4j.create(op.resultType(), new long[]{xT, yT}); } else { - if (op.y() != null) { + if (y != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if(op.x().length() == op.y().length()) { + if(x.length() == y.length()) { //Pairwise - if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { + if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + - Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")"); } } else { //Every X TAD vs. entirety of Y - val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension); + val xTADSize = x.length() / x.tensorsAlongDimension(dimension); - if (xTADSize != op.y().length()) { + if (xTADSize != y.length()) { throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" + - " (x TAD size = " + xTADSize + ", y size = " + op.y().length()); + " (x TAD size = " + xTADSize + ", y size = " + y.length()); } } } - ret = Nd4j.create(op.resultType(), retShape); + DataType dt = oc != null ? op.resultType(oc) : op.resultType(); + ret = Nd4j.create(dt, retShape); } - op.setZ(ret); + setZ(ret, op, oc); + z = ret; } else { // compare length long shapeProduct = (retShape.length == 0 ? 
1 : ArrayUtil.prodLong(retShape)); - if (!op.isComplexAccumulation() && op.z().length() != shapeProduct) { - if(!(op.x().isEmpty() && op.isKeepDims())){ + if (!op.isComplexAccumulation() && z.length() != shapeProduct) { + if(!(x.isEmpty() && op.isKeepDims())){ //Empty reductions are special case: [1,0].sum(0,1,keep=true) -> shape [1,1] - throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); + throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); } } else if (op.isComplexAccumulation()) { - long xT = op.x().tensorsAlongDimension(dimension); - long yT = op.y().tensorsAlongDimension(dimension); + long xT = x.tensorsAlongDimension(dimension); + long yT = y.tensorsAlongDimension(dimension); - if (op.z().length() != xT * yT) - throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + (xT * yT) + "]"); + if (z.length() != xT * yT) + throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + (xT * yT) + "]"); } - ret = op.z(); + ret = z; } - //log.info("X dtype: {}; Z dtype: {}", op.x().dataType(), op.z().dataType()); + //log.info("X dtype: {}; Z dtype: {}", x.dataType(), z.dataType()); /** * Returns the {@link Shape#createShapeInformation(int[], int[], int, int, char)} * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. */ - Pair tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null): tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = x.isEmpty() ? 
Pair.makePair(x.data(), null): tadManager.getTADOnlyShapeInfo(x, dimension); Pair yTadBuffers = null; /** * Note that we use addresses in libnd4j. * We use reinterpret cast in c to take the long * we pass to JNI. This manages overhead. */ - Pointer hostTadShapeInfo = op.x().isEmpty() ? op.x().shapeInfoDataBuffer().addressPointer() : tadBuffers.getFirst().addressPointer(); + Pointer hostTadShapeInfo = x.isEmpty() ? x.shapeInfoDataBuffer().addressPointer() : tadBuffers.getFirst().addressPointer(); - DataBuffer offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); + DataBuffer offsets = x.isEmpty() ? null : tadBuffers.getSecond(); Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer(); // we're going to check, if that's TAD vs TAD comparison or TAD vs full array. if later - we're going slightly different route boolean tvf = false; - if (op.y() != null) { - if (op.x().tensorAlongDimension(0, dimension).length() == op.y().length()) { + if (y != null) { + if (x.tensorAlongDimension(0, dimension).length() == y.length()) { tvf = true; } } if (op.isComplexAccumulation()) { - yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension); + yTadBuffers = tadManager.getTADOnlyShapeInfo(y, dimension); - if (op.x().tensorAlongDimension(0, dimension).length() != op.y().tensorAlongDimension(0, dimension).length()) + if (x.tensorAlongDimension(0, dimension).length() != y.tensorAlongDimension(0, dimension).length()) throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: " + - "x TAD length = " + op.x().tensorAlongDimension(0, dimension).length() + ", y TAD length " + - op.y().tensorAlongDimension(0, dimension).length()); + "x TAD length = " + x.tensorAlongDimension(0, dimension).length() + ", y TAD length " + + y.tensorAlongDimension(0, dimension).length()); } /** @@ -383,23 +406,23 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * This gives us a pointer which is 
passed around in libnd4j. */ Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); if (op instanceof Variance) { if (ret.isScalar()) { loop.execSummaryStatsScalar(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected()); } else { Variance var = (Variance) op; try { loop.execSummaryStatsTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, var.isBiasCorrected(), null, null); } catch (Throwable t){ @@ -410,15 +433,15 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } //pairwise reduction like similarity of two arrays - else if (op.y() != null && op.getOpType() == Op.Type.REDUCE3) { - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + else if (y != null && op.getOpType() == Op.Type.REDUCE3) { + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); if 
(op.isComplexAccumulation()) { try { loop.execReduce3All(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) tadBuffers.getFirst().addressPointer(), new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), (LongPointer) yTadBuffers.getFirst().addressPointer(), new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) @@ -429,17 +452,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } else if (ret.isScalar()) { loop.execReduce3Scalar(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); } else { try { loop.execReduce3Tad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) 
x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, null, null, null, null); } catch (Throwable t){ @@ -453,27 +476,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - 
getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -482,32 +505,32 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), 
null, + getPointerForExtraArgs(op, z.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool2(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType()), + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -520,7 +543,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - return ret; + return getZ(op, oc); } /** @@ -528,6 +551,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * @param op Op to execute */ private void invokeScalarAlongDimension(ScalarOp op) { + invokeScalarAlongDimension(op, null); + } + + private void invokeScalarAlongDimension(ScalarOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + val dimension = op.dimensions().toIntVector(); //dimension = Shape.normalizeAxis(op.x().rank(), dimension); // do tad magic @@ -561,16 +592,16 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = 
((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalarTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, @@ -578,9 +609,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { break; case SCALAR_BOOL: loop.execScalarBoolTad(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, @@ -594,56 +625,63 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new RuntimeException(loop.lastErrorMessage()); } - public INDArray exec(ScalarOp op) { + public INDArray exec(ScalarOp op){ + return exec(op, null); + } + + public INDArray exec(ScalarOp op, 
OpContext oc) { long st = profilingConfigurableHookIn(op); //validateDataType(Nd4j.dataType(), op); - if(op.z() == null){ + if((oc != null && oc.getOutputArray(0) == null) || getZ(op, oc) == null){ switch (op.getOpType()) { case SCALAR: - op.setZ(op.x().ulike()); + setZ(getX(op, oc).ulike(), op, oc); +// op.setZ(op.x().ulike()); break; case SCALAR_BOOL: - op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); +// op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + setZ(Nd4j.createUninitialized(DataType.BOOL, getX(op, oc).shape()), op, oc); break; default: throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); } } - if (op.x().length() != op.z().length()) +// if (op.x().length() != op.z().length()) + if (getX(op, oc).length() != getZ(op, oc).length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " + - "x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = [" - + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "], z shape info = [" - + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]"); + "x.length()=" + getX(op, oc).length() + ", z.length()=" + getZ(op, oc).length() + " - x shape info = [" + + Arrays.toString(getX(op, oc).shapeInfoDataBuffer().asInt()) + "], z shape info = [" + + Arrays.toString(getZ(op, oc).shapeInfoDataBuffer().asInt()) + "]"); if (op.dimensions() != null) { invokeScalarAlongDimension(op); - return op.z(); + return getZ(op, oc); } - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val x = ((BaseCpuDataBuffer) getX(op, oc).data()).getOpaqueDataBuffer(); val scalar = ((BaseCpuDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) getZ(op, oc).data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalar(null, - op.opNum(), - x, (LongPointer) 
op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType())); + op.opNum(), + x, (LongPointer) getX(op, oc).shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) getZ(op, oc).shapeInfoDataBuffer().addressPointer(), null, + scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, getZ(op, oc).dataType())); break; case SCALAR_BOOL: loop.execScalarBool(null, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + x, (LongPointer) getX(op, oc).shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) getZ(op, oc).shapeInfoDataBuffer().addressPointer(), null, scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType())); + getPointerForExtraArgs(op, getX(op, oc).dataType())); break; default: throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); @@ -654,7 +692,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { profilingConfigurableHookOut(op, st); - return op.z(); + return getZ(op, oc); } private Pointer getPointerForExtraArgs(Op op, DataType type) { @@ -670,6 +708,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } private void exec(TransformOp op) { + exec(op, null); + } + + private void exec(TransformOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + long st = 0; // validateDataType(Nd4j.dataType(), op); @@ -681,8 +727,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { // Pow operations might be special if (op.opNum() == 31) { - if (op.y() != null && op.y().isScalar()) { - 
op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0))); + if (y != null && y.isScalar()) { +// op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0))); + setY(Nd4j.valueArrayOf(x.shape(), y.getDouble(0)), op, oc); } } @@ -723,33 +770,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } else st = profilingConfigurableHookIn(op); - if (op.y() != null) { + if (y != null) { - if (op.z() == null) - op.setZ(Nd4j.create(op.resultType(), op.x().shape())); + if (z == null) + setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); +// op.setZ(Nd4j.create(op.resultType(), op.x().shape())); - op.validateDataTypes(experimentalMode.get()); + op.validateDataTypes(oc, experimentalMode.get()); //log.info("X type: {}; Y type: {}; Z type: {}; OpNum: {}", op.x().dataType(), op.y().dataType(), op.z().dataType(), op.opNum()); - int xEWS = op.x().elementWiseStride(); - int yEWS = op.y().elementWiseStride(); - int zEWS = op.z().elementWiseStride(); - - boolean xRow = op.x().isRowVector(); - boolean yRow = op.y().isRowVector(); - boolean zRow = op.z().isRowVector(); - - if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) + if (x.length() != y.length() || x.length() != z.length()) throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform " + - op.opName() + ". x: length " + op.x().length() + ", shape " + Arrays.toString(op.x().shape()) + - "; y: " + op.y().length() + ", shape " + Arrays.toString(op.y().shape()) + - "; z: " + op.z().length() + ", shape " + Arrays.toString(op.z().shape())); + op.opName() + ". 
x: length " + x.length() + ", shape " + Arrays.toString(x.shape()) + + "; y: " + y.length() + ", shape " + Arrays.toString(y.shape()) + + "; z: " + z.length() + ", shape " + Arrays.toString(z.shape())); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_ANY: @@ -757,78 +797,81 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { case TRANSFORM_STRICT: case TRANSFORM_SAME: if (!experimentalMode.get()) - Preconditions.checkArgument(op.x().dataType() == op.y().dataType() || op.y().dataType() == DataType.BOOL, "Op.X and Op.Y must have the same data type, but got " + op.x().dataType() + " vs " + op.y().dataType()); + Preconditions.checkArgument(x.dataType() == y.dataType() || y.dataType() == DataType.BOOL, + "Op.X and Op.Y must have the same data type, but got %s vs. 
%s", x.dataType(), y.dataType()); loop.execPairwiseTransform(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.z().dataType())); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, z.dataType())); break; case TRANSFORM_BOOL: case PAIRWISE_BOOL: loop.execPairwiseTransformBool(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - getPointerForExtraArgs(op, op.x().dataType())); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, x.dataType())); break; } } else { - if (op.z() == null) - op.setZ(Nd4j.createUninitialized(op.resultType(), op.x().shape())); + if (z == null) { + setZ(Nd4j.createUninitialized((oc != null ? 
op.resultType(oc) : op.resultType()), x.shape()), op, oc); + z = getZ(op, oc); + } - op.validateDataTypes(experimentalMode.get()); + op.validateDataTypes(oc, experimentalMode.get()); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_FLOAT: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformFloat(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_STRICT: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformStrict(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_SAME: { - val xtraz = getPointerForExtraArgs(op, op.z().dataType()); + val xtraz = getPointerForExtraArgs(op, z.dataType()); loop.execTransformSame(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_ANY: { - val xtraz = getPointerForExtraArgs(op, op.x().dataType()); + 
val xtraz = getPointerForExtraArgs(op, x.dataType()); val opNum = op.opNum(); loop.execTransformAny(dummy, opNum, - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } case TRANSFORM_BOOL: { - val xtraz = getPointerForExtraArgs(op, op.x().dataType()); + val xtraz = getPointerForExtraArgs(op, x.dataType()); val opNum = op.opNum(); loop.execTransformBool(dummy, opNum, - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -845,6 +888,14 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } public INDArray exec(BroadcastOp op) { + return exec(op, null); + } + + public INDArray exec(BroadcastOp op, OpContext oc) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + long st = profilingConfigurableHookIn(op); op.validateDataTypes(experimentalMode.get()); @@ -856,7 +907,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. 
*/ - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension); Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); Pointer hostTadOffsets = tadBuffers.getSecond().addressPointer(); @@ -864,17 +915,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer devTadShapeInfoZ = null; Pointer devTadOffsetsZ = null; - // if (!Arrays.equals(op.x().shape(),op.z().shape()) || !Arrays.equals(op.x().stride(),op.z().stride()) || op.x().ordering() != op.z().ordering()) { + // if (!Arrays.equals(x.shape(),z.shape()) || !Arrays.equals(x.stride(),z.stride()) || x.ordering() != z.ordering()) { // that's the place where we're going to have second TAD in place - Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension); + Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, dimension); devTadShapeInfoZ = tadBuffersZ.getFirst().addressPointer(); devTadOffsetsZ = tadBuffersZ.getSecond().addressPointer(); /* log.info("Broascast dimension: {}", Arrays.toString(dimension)); - log.info("x shape: {}; x TAD: {}; comp TAD: {}", Arrays.toString(op.x().shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffers.getFirst().asInt()), Arrays.toString(op.x().tensorAlongDimension(0, dimension).shapeInfoDataBuffer().asInt())); - log.info("z shape: {}; z TAD: {}", Arrays.toString(op.z().shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffersZ.getFirst().asInt())); - log.info("y shape: {}", Arrays.toString(op.y().shapeInfoDataBuffer().asInt())); + log.info("x shape: {}; x TAD: {}; comp TAD: {}", Arrays.toString(x.shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffers.getFirst().asInt()), Arrays.toString(x.tensorAlongDimension(0, dimension).shapeInfoDataBuffer().asInt())); + log.info("z shape: {}; z TAD: {}", Arrays.toString(z.shapeInfoDataBuffer().asInt()), Arrays.toString(tadBuffersZ.getFirst().asInt())); + log.info("y shape: {}", 
Arrays.toString(y.shapeInfoDataBuffer().asInt())); log.info("-------------"); */ @@ -885,23 +936,23 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); - val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case BROADCAST: loop.execBroadcast(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: loop.execBroadcastBool(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, null, ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; @@ -912,7 +963,7 @@ public class 
NativeOpExecutioner extends DefaultOpExecutioner { if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - return op.z(); + return z; } @@ -1202,6 +1253,22 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { */ @Override public INDArray exec(RandomOp op, Random rng) { + return exec(op, null, rng); + } + + + public INDArray exec(RandomOp op, OpContext oc, Random rng) { + INDArray x = getX(op, oc); + INDArray y = getY(op, oc); + INDArray z = getZ(op, oc); + + if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){ + //Ugly hack to ensure the triple arg call occurs + //See GaussianDistribution.setZ etc + x = z; + y = z; + } + if (!(rng instanceof CpuNativeRandom)) throw new IllegalStateException( "You should use one of NativeRandom classes for NativeOperations execution. Op class: " + op.getClass().getName()); @@ -1210,30 +1277,30 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { //validateDataType(Nd4j.dataType(), op); - Preconditions.checkArgument(op.z().isR(), "Op.Z must have one of floating point types"); + Preconditions.checkArgument(z.isR(), "Op.Z must have one of floating point types"); - val x = op.x() == null ? null : ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); - val y = op.y() == null ? null : ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); - val z = op.z() == null ? null : ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + val xb = x == null ? null : ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer(); + val yb = y == null ? null : ((BaseCpuDataBuffer) y.data()).getOpaqueDataBuffer(); + val zb = z == null ? 
null : ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer(); - if (op.x() != null && op.y() != null && op.z() != null) { + if (x != null && y != null && z != null) { // triple arg call loop.execRandom3(null, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); - } else if (op.x() != null && op.z() != null) { + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + yb, (LongPointer) y.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); + } else if (x != null && z != null) { //double arg call loop.execRandom2(null, op.opNum(), rng.getStatePointer(), // rng state ptr - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); + xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null, + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); } else { // single arg call loop.execRandom(null, op.opNum(), rng.getStatePointer(), // rng state ptr - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, - op.extraArgsDataBuff(op.z().dataType()).addressPointer()); + zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null, + op.extraArgsDataBuff(z.dataType()).addressPointer()); } if (loop.lastErrorCode() != 0) @@ -1241,7 +1308,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { profilingConfigurableHookOut(op, st); - return op.z(); + return z; } @Override @@ -1678,11 +1745,17 @@ public class NativeOpExecutioner extends 
DefaultOpExecutioner { @Override public List calculateOutputShape(@NonNull CustomOp op) { + return calculateOutputShape(op, null); + } + + @Override + public List calculateOutputShape(@NonNull CustomOp op, OpContext opContext) { val lc = op.opName().toLowerCase(); val hash = op.opHash(); val result = new ArrayList(); - if(op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) { + int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); + if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); @@ -1690,10 +1763,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - - val inputBuffers = new PointerPointer<>(op.numInputArguments()); - val inputShapes = new PointerPointer<>(op.numInputArguments()); - val inputArgs = op.inputArguments(); + val inputBuffers = new PointerPointer<>(nIn); + val inputShapes = new PointerPointer<>(nIn); + val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments(); int cnt= 0; for (val in: inputArgs) { if (!in.isEmpty()) @@ -1703,76 +1775,95 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } - val iArgs = op.numIArguments() > 0 ? new LongPointer(op.numIArguments()) : null; + int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments(); + val iArgs = nIArgs > 0 ? new LongPointer(nIArgs) : null; cnt = 0; - val iArgs1 = op.iArgs(); - for (val i: iArgs1) - iArgs.put(cnt++, i); + if(opContext != null){ + for (val i: opContext.getIArguments()) + iArgs.put(cnt++, i); + } else { + for (val i: op.iArgs()) + iArgs.put(cnt++, i); + } - val tArgs = op.numTArguments() > 0 ? new DoublePointer(op.numTArguments()) : null; - val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null; + int nTArgs = opContext != null ? 
opContext.numTArguments() : op.numTArguments(); + val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null; - val dArgs = op.numDArguments() > 0 ? new IntPointer(op.numDArguments()) : null; + int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments(); + val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null; - cnt = 0; - val bArgs1 = op.bArgs(); - for (val b: bArgs1) + int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments(); + val dArgs = nDArgs > 0 ? new IntPointer(nDArgs) : null; + + cnt = 0; + if(opContext != null){ + for (val b: opContext.getBArguments()) bArgs.put(cnt++, b); - - cnt = 0; - val tArgs1 = op.tArgs(); - for (val t: tArgs1) - tArgs.put(cnt++, t); - - cnt = 0; - val dArgs1 = op.dArgs(); - for (val d: dArgs1) - dArgs.put(cnt++, d.toInt()); + } else { + for (val b: op.bArgs()) + bArgs.put(cnt++, b); + } - OpaqueShapeList ptrptr; - try { - ptrptr = loop.calculateOutputShapes2(null, - hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, - op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments()); + cnt = 0; + if(opContext != null){ + for (val b: opContext.getTArguments()) + tArgs.put(cnt++, b); + } else { + for (val b: op.tArgs()) + tArgs.put(cnt++, b); + } - if (loop.lastErrorCode() != 0) - throw new RuntimeException(loop.lastErrorMessage()); - } catch (Throwable t){ - StringBuilder sb = new StringBuilder(); - sb.append("Inputs: [("); - for( int i=0; i 0) - sb.append("), ("); - sb.append(Shape.shapeToStringShort(inputArgs.get(i))); - } - sb.append(")]"); - if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ - appendSameDiffInfo(sb, (DifferentialFunction) op); - } + cnt = 0; + if(opContext != null){ + for (val b: opContext.getDArguments()) + dArgs.put(cnt++, b.toInt()); + } else { + for (val b: op.dArgs()) + dArgs.put(cnt++, b.toInt()); + } - log.error("Failed to calculate output shapes for op " + op.opName() + 
". Attempted to execute with " + - String.valueOf(op.numInputArguments()) + " inputs, " + - String.valueOf(op.numOutputArguments()) + " outputs, "+ - String.valueOf(op.numTArguments()) + " targs and " + - String.valueOf(op.numIArguments()) + " iargs. " + - sb.toString() + - " - Please see above message (printed out from c++) for a possible cause of error."); - throw t; + + OpaqueShapeList ptrptr; + try { + ptrptr = loop.calculateOutputShapes2(null, + hash, inputBuffers, inputShapes, nIn, tArgs, + nTArgs, iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + } catch (Throwable t){ + StringBuilder sb = new StringBuilder(); + sb.append("Inputs: [("); + for( int i=0; i 0) + sb.append("), ("); + sb.append(Shape.shapeToStringShort(inputArgs.get(i))); } + sb.append(")]"); + if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ + appendSameDiffInfo(sb, (DifferentialFunction) op); + } + + int nOut = opContext != null ? opContext.numOutputArguments() : op.numOutputArguments(); + log.error("Failed to calculate output shapes for op {}. Attempted to execute with {} inputs, {} outputs, " + + "{} targs, {} iargs, {} bargs and {} dargs. 
{} - Please see above message (printed out from c++) for a possible cause of error.", + op.opName(), nIn, nOut, nTArgs, nIArgs, nBArgs, nDArgs, sb.toString()); + throw t; + } if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); - if (ptrptr == null) - throw new RuntimeException(); + if (ptrptr == null) + throw new RuntimeException(); - for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ ) - result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer())); + for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ ) + result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer())); - loop.deleteShapeList(ptrptr); + loop.deleteShapeList(ptrptr); if(log.isTraceEnabled()){ String[] arr = new String[result.size()]; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 6c9633a41..4f228717a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -385,6 +385,7 @@ public class RandomOpValidation extends BaseOpValidation { @Test public void testUniformDtype(){ + Nd4j.getRandom().setSeed(12345); for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ SameDiff sd = SameDiff.create(); SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java new file mode 100644 index 000000000..7addd5098 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java @@ -0,0 +1,169 @@ 
+package org.nd4j.autodiff.samediff; + +import lombok.extern.slf4j.Slf4j; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; +import org.nd4j.imports.TFGraphs.TFGraphTestZooModels; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.AtomicBoolean; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +@Slf4j +public class SameDiffMultiThreadTests extends BaseND4JTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + + @Test + public void testSimple() throws Exception { + + int nThreads = 4; + int nRuns = 1000; + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10); + + SDVariable w1 = sd.var("w1", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b1 = sd.var("b1", Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w2 = sd.var("w2", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b2 = sd.var("b2", Nd4j.rand(DataType.FLOAT, 10)); + SDVariable w3 = sd.var("w3", Nd4j.rand(DataType.FLOAT, 10, 10)); + SDVariable b3 = sd.var("b3", Nd4j.rand(DataType.FLOAT, 10)); + + SDVariable l1 = sd.nn.tanh(in.mmul(w1).add(b1)); + SDVariable l2 = sd.nn.sigmoid(l1.mmul(w2).add(b2)); + SDVariable l3 = sd.nn.softmax("out", l2.mmul(w3).add(b3)); + + SDVariable loss = sd.loss.logLoss("loss", label, l3); + + INDArray[] inputArrs = new INDArray[nThreads]; + INDArray[] expOut = new INDArray[nThreads]; + for( int i=0; i 2) + inputArrs[i] = 
Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3); + else if(i == 1) + inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3); + else if(i == 2) + inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3); + + expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1"); + Nd4j.getExecutioner().commit(); + } + + AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads]; + AtomicInteger[] counters = new AtomicInteger[nThreads]; + Semaphore s = new Semaphore(nThreads); + CountDownLatch latch = new CountDownLatch(nThreads); + + doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch); + + s.release(nThreads); + latch.await(); + + for(int i=0; i op.z() } + def exec(op: Op, context: OpContext): INDArray = + Nd4j.getExecutioner.exec(op, context) + def exec(op: FilterOps): INDArray = { val retVal: INDArray = Nd4j.create(op.x.dataType(), op.x.shape().map(_.toLong): _*) for (i <- 0 until op.x().length().toInt) { @@ -408,6 +411,9 @@ class FunctionalOpExecutioner extends OpExecutioner { def calculateOutputShape(op: CustomOp): java.util.List[LongShapeDescriptor] = Nd4j.getExecutioner.calculateOutputShape(op) + def calculateOutputShape(op: CustomOp, ctx: OpContext): java.util.List[LongShapeDescriptor] = + Nd4j.getExecutioner.calculateOutputShape(op, ctx) + /** * Equivalent to calli */ From 5a34ccf3d43abae34ee9d94be178c4cce7385f90 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 20 Mar 2020 23:50:17 +1100 Subject: [PATCH 09/17] Remove printf in ones_as c++ op (#336) Signed-off-by: Alex Black --- libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp index dccebf8c9..32ce54300 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp +++ 
b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp @@ -38,7 +38,7 @@ namespace sd { auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); auto shape = sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); - nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str()); + //nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str()); return SHAPELIST(shape); } From 015147b71331d852a7741e06d079e847d620a0a3 Mon Sep 17 00:00:00 2001 From: Adam Gibson <1144306+agibsonccc@users.noreply.github.com> Date: Sat, 21 Mar 2020 17:30:26 +0900 Subject: [PATCH 10/17] Fix openblas linking issues (#340) * Fix cmake detection in msys * Revert windows change * Update to unix line endings * Fix linking issues --- nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index a964f6918..466bc6264 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -155,8 +155,8 @@ /${javacpp.platform.library.path}/ /${javacpp.platform.library.path}/lib/ - + /org/bytedeco/openblas/${javacpp.platform}/ + /org/bytedeco/openblas/${javacpp.platform}/lib/ From 69c92ca5ae517bcab4f0aad0eec92b128b01644d Mon Sep 17 00:00:00 2001 From: Oleh Date: Mon, 23 Mar 2020 06:28:31 +0200 Subject: [PATCH 11/17] Learning updaters for gradient (#335) * libnd4j raw implementation of sgd upader Signed-off-by: Oleg * libnd4j some corrections and simple test added Signed-off-by: Oleg * libnd4j some corrections after discussion Signed-off-by: Oleg * libnd4j integrate applyScalar Signed-off-by: Oleg * libnd4j raw implementation of rmsPropUpdater on cpu Signed-off-by: Oleg * libnd4j fix operations declaration Signed-off-by: Oleg * libnd4j rmsPropUpdater added, test 
cases for sgd, etc Signed-off-by: Oleg * libnd4j fixed several typos Signed-off-by: Oleg * libnd4j some fixes and improvements for rmsPropUpdater based on Java tests Signed-off-by: Oleg * libnd4j fixed cuda implementation, update tests and corrected behavior according java tests Signed-off-by: Oleg * libnd4j adaGrad updater added Signed-off-by: Oleg * libnd4j one minor fix for ada grad Signed-off-by: Oleg * libnd4j several more fixes for ada_grad Signed-off-by: Oleg * libnd4j nesterovs updater added Signed-off-by: Oleg * libnd4j fixed nesterovs updater behavior, several typos and rename file Signed-off-by: Oleg * libnd4j one minor typo Signed-off-by: Oleg * libnd4j ada max updater added Signed-off-by: Oleg * libnd4j fixed several typos in adaMax updater Signed-off-by: Oleg * libnd4j fixed several typos in adaMaxUpdater Signed-off-by: Oleg * libnd4j several fixes for adaMax, added Adam Updater Signed-off-by: Oleg * libnd4j adaDeltaUpdater added, minor fixes for adamUpdater Signed-off-by: Oleg * libnd4j several fixes for adaDeltaUpdater Signed-off-by: Oleg * libnd4j nadamUpdater added Signed-off-by: Oleg * libnd4j one more correction for nadam updater Signed-off-by: Oleg * libnd4j several fixes for nadam updater and added amsGradUpdater Signed-off-by: Oleg * libnd4j several typos fixed in amsGradUpdater Signed-off-by: Oleg * libnd4j some corrections and added f order support rmsProp updater Signed-off-by: Oleg * libnd4j added support of f order for all updaters and modify tests for testing in place Signed-off-by: Oleg * libnd4j fixed issues for updates when not in place mode used, added tests for f order Signed-off-by: Oleg * libnd4j added input shape checks Signed-off-by: Oleg * libnd4j some corrections for different cases handling Signed-off-by: Oleg * libnd4j some code clean up and optimize per request Signed-off-by: Oleg * libnd4j updaters refactoring after review Signed-off-by: Oleg * SgdUpdater wrapper Signed-off-by: raver119 * first test Signed-off-by: 
raver119 * RmsPropUpdater added Signed-off-by: raver119 * NadamUpdater + NesterovsUpdater Signed-off-by: raver119 * AmsGradUpdater Signed-off-by: raver119 * AdamUpdater added Signed-off-by: raver119 * AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater Signed-off-by: raver119 * AdaGradUpdater test added Signed-off-by: raver119 * libnd4j remove input parameters parsing through NDArray, split implementation of helpers to separate files, added some rename, etc Signed-off-by: Oleg * libnd4j next step to split operations implementation into separate files Signed-off-by: Oleg * libnd4j merge master and minor corrections Signed-off-by: Oleg * libnd4j revert some changes of split implementation Signed-off-by: Oleg * libnd4j forgot to add header file Signed-off-by: Oleg * public default constructors Signed-off-by: raver119 * ImportClassMapping updated Signed-off-by: raver119 Co-authored-by: raver119 --- .../include/ops/declarable/CustomOperations.h | 1 + .../generic/updaters/adaDeltaUpdater.cpp | 81 ++ .../generic/updaters/adaGradUpdater.cpp | 77 ++ .../generic/updaters/adaMaxUpdater.cpp | 93 ++ .../generic/updaters/adamUpdater.cpp | 92 ++ .../generic/updaters/amsGradUpdater.cpp | 98 ++ .../generic/updaters/nadamUpdater.cpp | 92 ++ .../generic/updaters/nesterovsUpdater.cpp | 75 + .../generic/updaters/rmsPropUpdater.cpp | 80 ++ .../generic/updaters/sgdUpdater.cpp | 61 + .../include/ops/declarable/headers/updaters.h | 210 +++ .../helpers/cpu/updaterAdaDelta.cpp | 108 ++ .../declarable/helpers/cpu/updaterAdaGrad.cpp | 91 ++ .../declarable/helpers/cpu/updaterAdaMax.cpp | 113 ++ .../declarable/helpers/cpu/updaterAdam.cpp | 113 ++ .../declarable/helpers/cpu/updaterAmsGrad.cpp | 126 ++ .../declarable/helpers/cpu/updaterNadam.cpp | 116 ++ .../helpers/cpu/updaterNesterovs.cpp | 91 ++ .../declarable/helpers/cpu/updaterRmsProp.cpp | 91 ++ .../helpers/cuda/updaterAdaDelta.cu | 129 ++ .../declarable/helpers/cuda/updaterAdaGrad.cu | 117 ++ .../declarable/helpers/cuda/updaterAdaMax.cu | 142 
++ .../declarable/helpers/cuda/updaterAdam.cu | 139 ++ .../declarable/helpers/cuda/updaterAmsGrad.cu | 152 ++ .../declarable/helpers/cuda/updaterNadam.cu | 137 ++ .../helpers/cuda/updaterNesterovs.cu | 117 ++ .../declarable/helpers/cuda/updaterRmsProp.cu | 121 ++ .../ops/declarable/helpers/updatersHelpers.h | 44 + .../layers_tests/DeclarableOpsTests18.cpp | 1229 +++++++++++++++++ .../converters/ImportClassMapping.java | 9 + .../ops/impl/updaters/AdaDeltaUpdater.java | 47 + .../api/ops/impl/updaters/AdaGradUpdater.java | 47 + .../api/ops/impl/updaters/AdaMaxUpdater.java | 48 + .../api/ops/impl/updaters/AdamUpdater.java | 48 + .../api/ops/impl/updaters/AmsGradUpdater.java | 48 + .../api/ops/impl/updaters/NadamUpdater.java | 48 + .../ops/impl/updaters/NesterovsUpdater.java | 47 + .../api/ops/impl/updaters/RmsPropUpdater.java | 47 + .../api/ops/impl/updaters/SgdUpdater.java | 47 + .../java/org/nd4j/nativeblas/Nd4jCuda.java | 1 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 1 + .../linalg/learning/UpdaterValidation.java | 74 +- 42 files changed, 4646 insertions(+), 2 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/headers/updaters.h create mode 100644 
libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu create mode 100644 libnd4j/include/ops/declarable/helpers/updatersHelpers.h create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java create mode 100644 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 1a1624c08..f98deb784 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp new file mode 100644 index 000000000..bab205543 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateMsg = INPUT_VARIABLE(1); + const auto initStateMsdx = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateMsg = OUTPUT_VARIABLE(1); + auto stateMsdx = OUTPUT_VARIABLE(2); + + if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsg->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsdx->getShapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); + + double dRho, dEpsilon; + + if (block.width() > 3) { + const auto rho = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dRho = rho->e(0); + dEpsilon = epsilon->e(0); + } + else { + dRho = T_ARG(0); + dEpsilon = T_ARG(1); + } + + 
helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon); + return Status::OK(); + } + + DECLARE_TYPES(ada_delta_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp new file mode 100644 index 000000000..a7a92b410 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateH = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + + bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); + + double dLr, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto epsilon = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon); + return Status::OK(); + } + + DECLARE_TYPES(ada_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp new file mode 
100644 index 000000000..4e34c24f6 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", 
ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + + + bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + + int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(ada_max_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp new file mode 100644 index 000000000..a696d2388 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp @@ -0,0 +1,92 @@ 
+/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + + bool bParamsSupply = 7 == 
block.width() || 4 == block.getTArguments()->size(); + + auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(adam_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp new file mode 100644 index 000000000..bc0f4beac --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + const auto initStateH = INPUT_VARIABLE(3); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + auto stateH = OUTPUT_VARIABLE(3); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state Msg must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + 
REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient!," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateH->getShapeInfo()).c_str()); + + bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size(); + + auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 4) { + const auto lr = INPUT_VARIABLE(4); + const auto beta1 = INPUT_VARIABLE(5); + const auto beta2 = INPUT_VARIABLE(6); + const auto epsilon = INPUT_VARIABLE(7); + + REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH, + *update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(ams_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp 
b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp new file mode 100644 index 000000000..c6af0686b --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same 
shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + + auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration); + return Status::OK(); + } + + DECLARE_TYPES(nadam_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp new file mode 100644 index 000000000..c77abd448 --- /dev/null +++ 
b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state Msg must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); + + double dLr, dMomentum; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto momentum = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: 
Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf()); + + dLr = lr->e(0); + dMomentum = momentum->e(0); + } + else { + dLr = T_ARG(0); + dMomentum = T_ARG(1); + } + helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum); + return Status::OK(); + } + + DECLARE_TYPES(nesterovs_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp new file mode 100644 index 000000000..1ca318e26 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateG = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROB UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!"); + + double dLr, dRmsDecay, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto rmsDecay = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(lr->isScalar(), 0, "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dRmsDecay = rmsDecay->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dRmsDecay = T_ARG(1); + dEpsilon = T_ARG(2); + } + + helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon); + return Status::OK(); + } + + 
DECLARE_TYPES(rms_prop_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp new file mode 100644 index 000000000..491d7b53e --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { + + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size(); + + REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!"); + + if (block.width() > 1) { + const auto lr = INPUT_VARIABLE(1); + REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + + input->applyScalarArr(scalar::Multiply, *lr, *output); + } + else { + input->applyScalar(scalar::Multiply, T_ARG(0), *output); + } + + return Status::OK(); + } + + DECLARE_TYPES(sgd_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/headers/updaters.h b/libnd4j/include/ops/declarable/headers/updaters.h new file mode 100644 index 000000000..dc08ff1f2 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/updaters.h @@ -0,0 +1,210 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + + +#ifndef LIBND4J_HEADERS_UPDATERS_H +#define LIBND4J_HEADERS_UPDATERS_H + +#include +#include +#include +#include +#include + + +namespace sd { + namespace ops { + + + /** + * SGD updater + * Input arrays: + * 0 - input array with gradients. + * Optional: + * 1 - scalar learning rate value + * Optional: + * T args + * 0 - scalar learning rate value + */ +#if NOT_EXCLUDED(OP_sgd_updater) + DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); +#endif + + /** + * RmsPropUpdater updater + * Input arrays: + * 0 - input array with gradients. + * 1 - Initial state + * Optional: + * 2 - scalar learning rate value + * 3 - scalar rms decay + * 4 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - scalar rms decay + * 2 - epsilon + */ +#if NOT_EXCLUDED(OP_rms_prop_updater) + DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); +#endif + // AdaGrad + /* Input arrays : + * 0 - input array with gradients. + * 1 - historical grad state + * Optional : + * 2 - scalar learning rate value + * 3 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - epsilon + */ +#if NOT_EXCLUDED(OP_ada_grad_updater) + DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); +#endif + // AdaMax + /* Input arrays : + * 0 - input array with gradients. 
+ * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_ada_max_updater) + DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); +#endif + // Nesterov's momentum + /* Input arrays : + * 0 - input array with gradients. + * 1 - V grad state + * Optional : + * 2 - scalar learning rate value + * 3 - scalar momentum value + * Optional: + * T args + * 0 - learning rate value + * 1 - momentum value + */ +#if NOT_EXCLUDED(OP_nesterovs_updater) + DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); +#endif + // Adam + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_adam_updater) + DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); +#endif + // AdaDelta + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - rho value + * 6 - epsilon + * Optional: + * T args + * 0 - rho + * 1 - epsilon + */ +#if NOT_EXCLUDED(OP_ada_delta_updater) + DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); +#endif + // Nadam + /* Input arrays : + * 0 - input array with gradients. 
+ * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_nadam_updater) + DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); +#endif + // AmsGrad + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V - sqrd gradients + * 2 - gradient state M - moving avg + * 3 - gradient state H - max + * Optional : + * 4 - scalar learning rate value + * 5 - beta 1 value + * 6 - beta 2 value + * 7 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_ams_grad_updater) + DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); +#endif +} +} + +#endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp new file mode 100644 index 000000000..e80018348 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp @@ -0,0 +1,108 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* initMsg = initStateMsg.bufferAsT(); + const T* initMsdx = initStateMsdx.bufferAsT(); + + T* up = update.bufferAsT(); + T* stMsg = stateMsg.bufferAsT(); + T* stMsdx = stateMsdx.bufferAsT(); + + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + const T rhoT = (1 - rho); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateMsdx.ordering() && + stateMsdx.ordering() == initStateMsdx.ordering() && + stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; + + up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt(stMsg[i] + epsilon)); + + stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInMsgSame = 
shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsg.getShapeInfo()); + bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsg.getShapeInfo()); + bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsdx.getShapeInfo()); + bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsdx.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < gradient.lengthOf(); i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.getShapeInfo(), coords); + const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.getShapeInfo(), coords); + const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.getShapeInfo(), coords); + const auto stMsdxOffset = bXStMsdxSame ? 
xOffset : shape::getOffset(stateMsdx.getShapeInfo(), coords); + + + stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp new file mode 100644 index 000000000..280597d31 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + st[i] = init[i] + grad[i] * grad[i]; + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + + const auto zOffset = bXZsame ? 
xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords); + + st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; + up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp new file mode 100644 index 000000000..ae986f901 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + //m = B_1 * m + (1-B_1)*grad + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[i] = sd::math::nd4j_max((beta2 * 
initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32; + + up[i] = stM[i] * epsilonT / stU[i]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords); + const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + //m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; + + up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp new file mode 100644 index 000000000..b8eab1e6f --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + T epsilonT = lr * sd::math::nd4j_sqrt(1. 
- beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); + + up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords); + const auto stUOffset = bXStVSame ? 
xOffset : shape::getOffset(stateU.getShapeInfo(), coords); + const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp new file mode 100644 index 000000000..686c22cbe --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + const T* initH = initStateH.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + T* stH = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + T epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && + 1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + 
stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering() && + stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * mbeta1; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + stH[i] = sd::math::nd4j_max(initH[i], stV[i]); + + up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + bool bXInHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateH.getShapeInfo()); + bool bXStHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords); + const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? 
xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.getShapeInfo(), coords); + const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp new file mode 100644 index 000000000..82ade0f16 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if 
(bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + auto oneMinusBeta1Grad = grad[i] * mbeta1; + + stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + + up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords); + const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp new file mode 100644 index 000000000..82e21ace7 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateV.bufferAsT(); + + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + const T momentumT = (-momentum - 1); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + T prevState = momentum * init[i]; + st[i] = prevState - lr * grad[i]; + up[i] = prevState + momentumT * st[i]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = 
bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp new file mode 100644 index 000000000..a0b9f731e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +template +static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateG.bufferAsT(); + + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay) ; + up[i] = (lr * grad[i]) / ( math::nd4j_sqrt(st[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateG.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? 
xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.getShapeInfo(), coords); + + st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay) ; + up[zOffset] = (lr * grad[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu new file mode 100644 index 000000000..33272ff57 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu @@ -0,0 +1,129 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo, + const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, + const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { + + const auto grad = reinterpret_cast(vx); + const auto initMsg= reinterpret_cast(vinMsg); + const auto initMsdx = reinterpret_cast(vinMsdx); + + auto up = reinterpret_cast(vz); + auto stMsg = reinterpret_cast(vstMsg); + auto stMsdx = reinterpret_cast(vstMsdx); + + __shared__ Nd4jLong xLen; + __shared__ T rhoT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + rhoT = (1 - rho); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) && + 1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && + shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && + shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, 
inMsgShapeInfo); + bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); + bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); + bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); + stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); + initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); + stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords); + } + + stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo, + void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { + + const T rho = static_cast(dRho); + const T epsilon = 
static_cast(dEpsilon); + + adaDeltaUpdaterCuda << > > (vx, xShapeInfo, vinMsg, inMsgShapeInfo, + vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + + PointersManager manager(context, "adaDeltaUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateMsg.getSpecialBuffer(), initStateMsg.getSpecialShapeInfo(), initStateMsdx.getSpecialBuffer(), initStateMsdx.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(),stateMsg.getSpecialBuffer(), stateMsg.getSpecialShapeInfo(), + stateMsdx.getSpecialBuffer(), stateMsdx.getSpecialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu new file mode 100644 index 000000000..f0e77826d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const T lr, const T epsilon) { + + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + __shared__ Nd4jLong xLen; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = 
shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); + } + + st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; + up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dEpsilon) { + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + adaGradUpdaterCuda << > > (vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { + + PointersManager manager(context, "adaGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, 
(blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu new file mode 100644 index 000000000..514440304 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu @@ -0,0 +1,142 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T beta1T, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + beta1T = sd::math::nd4j_pow(beta1, (iteration + 1) ); + + epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) && + shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) && + shape::order(xShapeInfo) == 
shape::order(stvShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stmShapeInfo, coords); + } + + //m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max( (beta2* initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; + + up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, + const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + adaMaxUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, + zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "adaMaxUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, 
&initStateM }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), initStateU.getSpecialBuffer(), + initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), + stateU.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), + dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu new file mode 100644 index 000000000..e23f4a5ca --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV, + const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + epsilonT = lr * sd::math::nd4j_sqrt(1. 
- beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stmShapeInfo, coords); + } + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + adamUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "adamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), 
adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateU.getSpecialBuffer(), initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), stateU.getSpecialShapeInfo(), + stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + + NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu new file mode 100644 index 000000000..d24c83f17 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu @@ -0,0 +1,152 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + const auto initH = reinterpret_cast(vinh); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + auto stH = reinterpret_cast(vstH); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1, mbeta2, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == 
shape::elementWiseStride(invShapeInfo) && + 1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo); + + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) && + shape::order(sthShapeInfo) == shape::order(inhShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); + bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); + initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); + stHOffset = bXStHSame ? 
xOffset : shape::getOffset(sthShapeInfo, coords); + } + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + amsGradUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "amsGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS 
/ 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + initStateH.getSpecialBuffer(), initStateH.getSpecialShapeInfo(), update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), + stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); + + manager.synchronize(); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu new file mode 100644 index 000000000..2ac1ec99b --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu @@ -0,0 +1,137 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1T, mbeta1, mbeta2; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, 
zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stmShapeInfo, coords); + } + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + nadamUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "nadamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, 
&initStateV, &initStateM }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), + stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); + + manager.synchronize(); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu new file mode 100644 index 000000000..73616a5cd --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + + +/////////////////////////////////////////////////////////////////// +template +__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) { + + const auto grad = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ T momentumT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + momentumT = (-momentum - 1); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? 
xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); + } + + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dMomentum) { + + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + nesterovsUpdaterCuda << > > (vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, momentum); +} + +/////////////////////////////////////////////////////////////////// +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + + PointersManager manager(context, "nesterovsUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), dLr, dMomentum), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, 
&initState }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu new file mode 100644 index 000000000..de0a5dba1 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu @@ -0,0 +1,121 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const T lr, const T rmsDecay, const T epsilon) { + + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == 
shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + + bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? 
xOffset : shape::getOffset(stShapeInfo, coords); + } + + st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay) ; + up[zOffset] = (lr * x[xOffset]) / ( math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template <typename T> +linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + const T lr = static_cast<T>(dLr); + const T rmsDecay = static_cast<T>(dRmsDecay); + const T epsilon = static_cast<T>(dEpsilon); + + rmsPropUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + PointersManager manager(context, "rmsPropUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState }); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateG.getSpecialBuffer(), stateG.getSpecialShapeInfo(), + dLr, dRmsDecay, dEpsilon ), FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); + + 
manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h new file mode 100644 index 000000000..5bd89b487 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#ifndef LIBND4J_UPDATER_RMS_PROM_H +#define LIBND4J_UPDATER_RMS_PROM_H + +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + + void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon); + void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon); + void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double bMomentum); + void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& 
update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon); + void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + +} +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index f8de783c9..b1cafa073 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -187,3 +187,1232 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 
0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.001f); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient, &lr }, { &gradient }, {}, { }); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd2) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient }, { &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd3) { + + NDArray gradientC('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray updateC('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + NDArray gradient('f', { 1, 5 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + gradient.assign(gradientC); + update.assign(updateC); + + sd::ops::sgd_updater op; + + auto results = op.evaluate({ &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); +} 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm1) { + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto decay = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &decay, &epsilon }, { &grad0, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472121903197142, 0.4472095514452829, 0.4472135169488324, 0.44721352981195367, 0.44721349127249754 }, DataType::FLOAT32); + NDArray stateG0('c', { 1, 5 }, { 0.00164065126484513, 0.00055124687044416, 0.03816546608068996, 0.04711672627124962, 0.02749591463177582 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + + NDArray grad1('c', { 1, 5 }, { 0.0139725673943758, 0.19333727657794952, 0.9288347363471985, 0.9253600239753723, 0.3578299283981323 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init, &lr, &decay, &epsilon }, { &grad1, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.03528177364993147, 0.3952537075263024, 0.32964378302079766, 0.31269398966616074, 0.1984174163852542 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.00156838033358239, 0.00239264965265088, 0.07939389114891399, 0.08757544865627226, 0.03252323178305766 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + NDArray grad2('c', { 1, 5 }, { 0.5442887544631958, 0.5386605262756348, 0.884294331073761, 0.15599730610847473, 0.7259345054626465 }, DataType::FLOAT32); + status = 
op.execute({ &grad2, &init, &lr, &decay, &epsilon }, { &grad2, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.4262874753567082, 0.41582357367557454, 0.2613066321005825, 0.05369221235564697, 0.3034061716240995 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.01630247372865814, 0.01678077529839554, 0.11452301978992785, 0.0844134341991137, 0.05724611550496966 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray stateG0('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + 
ASSERT_TRUE(init.equalsTo(stateG1)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::rms_prop_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray 
update1C('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2C('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaGrad1) { + + // need Java test + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::ada_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &epsilon }, { &grad0, &init }, { }, { }); + 
ASSERT_EQ(ND4J_STATUS_OK, status); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs1) { + + NDArray grad0('c', { 1, 5 }, { 0.6877592206001282, 0.7830561399459839, 0.7647699117660522, 0.6183066964149475, 0.3303879499435425 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init }, { &grad0, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.13067425191402435, 0.14878066658973696, 0.14530628323554992, 0.11747827231884002, 0.06277371048927306 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.06877592206001282, -0.0783056139945984, -0.07647699117660522, -0.06183066964149475, -0.03303879499435425 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + NDArray grad1('c', { 1, 5 }, { 0.3676236569881439, 0.07645636051893234, 0.45949840545654297, 0.6335387825965881, 0.2953402101993561 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init }, { &grad1, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.12555699169635773, 0.07795425583422186, 0.14925105988979342, 0.17045521110296247, 0.08287606388330458 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.09866069555282593, -0.0781206886470318, -0.11477913260459902, -0.11900148093700408, -0.05926893651485443 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + NDArray grad2('c', { 1, 5 }, { 0.9874004125595093, 0.41817641258239746, 0.16838215291500092, 0.00803728867322206, 0.37015461921691895 }, DataType::FLOAT32); + status = op.execute({ &grad2, &init }, { &grad2, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 
0.26752124178409575, 0.1427312761947513, 0.12496370646357537, 0.09791828440688549, 0.11833721622824667 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { -0.18753466725349427, -0.11212626104056837, -0.12013943463563921, -0.10790506171062587, -0.09035750478506088 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto momentum = NDArrayFactory::create(0.9f); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray 
stateV2('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::nesterovs_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.9f }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + 
ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray stateG2C('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray 
stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax2) { + + NDArray grad0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.00538735911250114, 0.09700437784194944, 
0.08912011384963987, 0.08891847729682921, 0.01882378011941909 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6400517821311951, 0.3779360353946686, 0.35128724575042725, 0.6554615497589111, 0.8420050740242004 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00107575360832691, 0.00129089809294599, 0.00129546826560191, 0.00163878765669416, 0.00120120308808246 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.6400517821311951, 0.9690737346410752, 0.8903099373579025, 0.888295588195324, 0.8420050740242004 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.06885380141437052, 0.12509754359722136, 0.11533682703971859, 0.1455727845430374, 0.10114190950989721 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.5984494686126709, 0.05978915095329285, 0.5749519467353821, 0.2804091274738312, 0.0192152876406908 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00190508497658779, 0.00122473022928962, 0.00181352349370876, 0.00179237223044249, 0.00110500865710834 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.6394117303490638, 0.9681046609064341, 0.8894196274205446, 0.8874072926071286, 0.8411630689501762 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.12181336813420054, 0.11856670433282851, 0.16129833900928492, 0.15905641883611676, 0.09294924732297657 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + 
ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::ada_max_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 
0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateM.assign(stateM1C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateM.assign(stateM2C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); 
+ NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam2) { + + NDArray grad0('c', { 1, 5 }, { 0.7124611735343933, 0.7283763289451599, 0.8196553587913513, 0.9501070976257324, 0.2654055953025818 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999955614757, 0.00099999956584582, 0.00099999961419438, 0.0009999996671663, 0.00099999880851273 }, DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.00050760092379401, 0.00053053207656763, 0.00067183490719538, 0.00090270349695879, 0.00007044013001792 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.07124611735343932, 0.07283763289451597, 0.08196553587913512, 0.09501070976257323, 0.02654055953025817 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.4374369978904724, 0.11488933861255646, 0.6765823364257812, 0.7659900188446045, 0.04410457238554955 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00129067017716555, 0.00104532555849556, 0.00133106720937621, 0.00132869584719374, 0.00105226561254395 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.00069844444999364, 
0.00054320110461789, 0.00112892673025155, 0.00148854150243139, 0.00007231490319321 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.10786520540714262, 0.07704280346632002, 0.14142721593379973, 0.16210864067077635, 0.02829696081578731 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.496029257774353, 0.11621368676424026, 0.9112075567245483, 0.5717480182647705, 0.5975669026374817 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00150986322036664, 0.00108559662275258, 0.00156079502787382, 0.00150778241516558, 0.00130066803775601 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.00094379103011182, 0.00055616352450461, 0.00195809701495322, 0.00181394875731865, 0.00042932879141777 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.14668161064386365, 0.08095989179611204, 0.21840525001287456, 0.20307257843017573, 0.08522395499795674 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::adam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 
0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + 
ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 
0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + NDArray stateMsg1('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta2) { + + NDArray grad0('c', { 1, 5 }, { 0.22060230374336243, 0.10593396425247192, 0.9027279019355774, 0.831809401512146, 0.2733047902584076 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto rho = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.0e-6); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initMsg, &initMsdx, &rho, &epsilon }, { &grad0, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.0044712172817412, 0.00446815612502933, 0.00447208107763182, 0.004472071321461, 0.00447153735969189 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.00243326882084394, 0.0005611002391122, 0.04074588324665051, 0.03459534402219976, 0.00373477541890961 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.00000099958919903, 0.00000099822095788, 0.00000099997545825, 0.00000099997109521, 0.00000099973231796 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + NDArray grad1('c', { 1, 5 }, { 0.6351608633995056, 0.21878601610660553, 0.6470938920974731, 0.3742971122264862, 0.9453978538513184 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initMsg, &initMsdx, &rho, &epsilon }, { &grad1, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00598985959779411, 0.00571609509028959, 0.00374704195122062, 0.00265092283150538, 0.00608704322078556 }, DataType::FLOAT32); + NDArray stateMsg1('c', { 1, 5 }, { 0.02248307149952203, 0.00292641126934659, 0.05964511434381081, 0.03987049323214412, 0.0482368917512981 }, 
DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000274353063914, 0.00000258199706405, 0.00000165199285454, 0.00000130134213338, 0.00000280235046064 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + NDArray grad2('c', { 1, 5 }, { 0.8484492301940918, 0.9634076952934265, 0.6676893830299377, 0.4450211524963379, 0.32364124059677124 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initMsg, &initMsdx, &rho, &epsilon }, { &grad2, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00685468722145889, 0.00822128238053265, 0.00386965914609878, 0.00308849888680941, 0.00279277397245112 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.05735222273539331, 0.04918781007340889, 0.07895331423716523, 0.04777915987899536, 0.05106222979448406 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.00000495569095238, 0.00000583237140987, 0.00000231810630717, 0.0000017132162954, 0.00000305221226067 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msg + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msdx + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsg('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsdx('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initMsg.assign(initVC); + initMsdx.assign(initMC); + + sd::ops::ada_delta_updater op; + auto results = op.evaluate({ &grad, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + 
+ NDArray updateC('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsg('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + NDArray stateMsdx('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateMsg.assign(stateV0C); + stateMsdx.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, results.at(1), results.at(2) }, { 0.95, 1.0e-6 }, { }); + + NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + + NDArray stateV1C('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + update.assign(update1C); + stateMsg.assign(stateV1C); + stateMsdx.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + 
ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateMsg, &stateMsdx }, { 0.95f, 1.0e-6 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + update.assign(update2C); + stateMsg.assign(stateV2C); + stateMsdx.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, 
DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam2) { + + NDArray grad0('c', { 1, 5 }, { 
0.8047558665275574, 0.9653639197349548, 0.31240877509117126, 0.9530212879180908, 0.01295729912817478 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.06008325193356386, 0.0600832558615088, 0.06008321472550684, 0.06008325560661022, 0.0600818092240132 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00064763200471052, 0.00093192749752604, 0.00009759924275397, 0.00090824957522506, 0.0000001678916007 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.08047558665275573, 0.09653639197349546, 0.03124087750911712, 0.09530212879180906, 0.00129572991281748 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.9839006662368774, 0.8964805603027344, 0.3631269931793213, 0.00931886397302151, 0.6320028901100159 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06273730114378717, 0.0596708938019245, 0.06226533928512862, 0.02621380498466489, 0.06059567064824535 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00161504489372718, 0.00173467296502922, 0.00022936285668667, 0.00090742816687558, 0.0003995953768165 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.17081809461116787, 
0.17653080880641933, 0.06442948907613753, 0.08670380230993031, 0.06436644593253729 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.7712154984474182, 0.1282273381948471, 0.7019220590591431, 0.8883536458015442, 0.33057701587677 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06062658222261493, 0.04001212712739213, 0.06906390273197544, 0.05804376499107734, 0.05097529565845974 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00220820319387896, 0.00174938054232472, 0.00072182807082381, 0.0016956929387176, 0.00050847694486568 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2308578349947929, 0.1717004617452621, 0.12817874607443808, 0.16686878665909166, 0.09098750292696056 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::nadam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 
0.0600832717431994 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, 
&stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 
5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateH2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 
0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad2) { + + NDArray grad0('c', { 1, 5 }, { 0.5730348229408264, 0.04330538213253021, 0.249028742313385, 0.6514443755149841, 0.7017051577568054 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999944815292, 0.00099999269777932, 0.00099999873015716, 0.00099999951457465, 0.00099999954934402 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.0004923901284225 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.00049239012842255 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.05730348229408263, 0.00433053821325302, 0.0249028742313385, 0.0651444375514984, 0.07017051577568052 }, DataType::FLOAT32); + + 
ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6404328346252441, 0.9432603120803833, 0.45608729124069214, 0.9097326993942261, 0.748093843460083 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134565543815267, 0.00104022434054697, 0.00130914539820157, 0.00133725290576052, 0.0013453914974122 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.11561641752719877, 0.09822351559996603, 0.06802131593227385, 0.14960326373577115, 0.13796284854412078 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.46250319480895996, 0.09698919206857681, 0.21754667162895203, 0.46824514865875244, 0.6005083918571472 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00154098993679222, 0.00103399135000281, 0.00147364850040774, 0.00149693641196572, 0.00155078467854623 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateH2('c', { 
1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.1503050952553749, 0.09810008324682712, 0.08297385150194167, 0.1814674522280693, 0.1842174028754234 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initHC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + NDArray initH('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + initH.assign(initHC); + + sd::ops::ams_grad_updater op; + auto results = op.evaluate({ &grad, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateH0C('c', { 1, 5 }, { 
0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + stateH.assign(stateH0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + NDArray stateH1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + stateH.assign(stateH1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + 
ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + NDArray stateH2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + stateH.assign(stateH2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 327e3c52e..ebe27bd85 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -43,6 +43,15 @@ public class ImportClassMapping { private static final List> fnClasses = Arrays.>asList( org.nd4j.linalg.api.ops.DynamicCustomOp.class, 
org.nd4j.linalg.api.ops.NoOp.class, + org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater.class, org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java new file mode 100644 index 000000000..db87ad5e4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaDeltaUpdater extends DynamicCustomOp { + + public AdaDeltaUpdater() { + // + } + + public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, double rho, double epsilon) { + this(gradients, stateMsg, stateMsdx, gradients, stateMsg, stateMsdx, rho, epsilon); + } + + public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, @NonNull INDArray updates, @NonNull INDArray updatedStateMsg, @NonNull INDArray updatedStateMsdx, double rho, double epsilon) { + addInputArgument(gradients, stateMsg, stateMsdx); + addOutputArgument(updates, updatedStateMsg, updatedStateMsdx); + addTArgument(rho, epsilon); + } + + @Override + public String opName() { + return "ada_delta_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java new file mode 100644 index 000000000..e2304bdfb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaGradUpdater extends DynamicCustomOp { + + public AdaGradUpdater() { + // + } + + public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double epsilon) { + this(gradients, state, gradients, state, lr, epsilon); + } + + public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double epsilon) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, epsilon); + } + + @Override + public String opName() { + return "ada_grad_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java new file mode 100644 index 000000000..483078335 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaMaxUpdater extends DynamicCustomOp { + + public AdaMaxUpdater() { + // + } + + public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateU, stateM); + addOutputArgument(updates, updatedStateU, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "ada_max_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java new file mode 100644 index 000000000..1ab34ae52 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdamUpdater extends DynamicCustomOp { + + public AdamUpdater() { + // + } + + public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateU, stateM); + 
addOutputArgument(updates, updatedStateU, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "adam_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java new file mode 100644 index 000000000..35af113ad --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AmsGradUpdater extends DynamicCustomOp { + + public AmsGradUpdater() { + // + } + + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration); + } + + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateV, stateM, stateH); + addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "ams_grad_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java new file mode 100644 index 000000000..ad4f374b7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class NadamUpdater extends DynamicCustomOp { + + public NadamUpdater() { + // + } + + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateV, stateM); + addOutputArgument(updates, updatedStateV, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "nadam_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java new file mode 100644 index 000000000..a277f750f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class NesterovsUpdater extends DynamicCustomOp { + + public NesterovsUpdater() { + // + } + + public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double momentum) { + this(gradients, state, gradients, state, lr, momentum); + } + + public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double momentum) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, momentum); + } + + @Override + public String opName() { + return "nesterovs_updater"; + } +} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java new file mode 100644 index 000000000..aaf734ea8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class RmsPropUpdater extends DynamicCustomOp { + + public RmsPropUpdater() { + // + } + + public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double decay, double epsilon) { + this(gradients, state, gradients, state, lr, decay, epsilon); + } + + public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double decay, double epsilon) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, decay, epsilon); + } + + @Override + public String opName() { + return "rms_prop_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java new file mode 100644 index 000000000..ef40735a4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class SgdUpdater extends DynamicCustomOp { + + public SgdUpdater() { + // + } + + public SgdUpdater(@NonNull INDArray input, double lr) { + this(input, input, lr); + } + + public SgdUpdater(@NonNull INDArray input, @NonNull INDArray output, double lr) { + addInputArgument(input); + addOutputArgument(output); + addTArgument(lr); + } + + @Override + public String opName() { + return "sgd_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 17bf95031..33260da70 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -10686,6 +10686,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 80d5904a6..47791f865 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -12422,6 +12422,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); 
// #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index 660b178e4..4a4d6aab6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -15,10 +15,12 @@ ******************************************************************************/ package org.nd4j.linalg.learning; +import lombok.val; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.*; @@ -58,14 +60,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val msgu = msg.dup(); + val msdxu = msdx.dup(); UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(g3, msgu, msdxu, rho, epsilon)); + assertEquals(msg, state.get("msg")); assertEquals(msdx, state.get("msdx")); assertEquals(g1, g2); + + assertEquals(msg, msgu); + assertEquals(msdx, msdxu); + assertEquals(g1, g3); } } @@ -85,13 +96,20 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val su = s.dup(); UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon); u.applyUpdater(g2, i, 0); + Nd4j.exec(new 
org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(g3, su, lr, epsilon)); + assertEquals(s, state.get("grad")); assertEquals(g1, g2); + + assertEquals(s, su); + assertEquals(g1, g3); } } @@ -118,14 +136,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -150,14 +177,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -185,15 +221,26 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); + val hu = vH.dup(); UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new AmsGradUpdater(g3, vu, mu, hu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(vH, 
state.get("V_HAT")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(vH, hu); + assertEquals(g1, g3); } } @@ -219,14 +266,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val vu = v.dup(); + val mu = m.dup(); UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -247,13 +303,18 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val vu = v.dup(); UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum); - u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(g3, vu, lr, momentum)); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -275,13 +336,19 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val gu = g.dup(); UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps); - u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(g3, gu, lr,decay, eps)); assertEquals(g, state.get("G")); assertEquals(g1, g2); + + assertEquals(g, gu); + assertEquals(g1, g3); + } } @@ -294,11 +361,14 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 
= g1.dup(); + val g3 = g1.dup(); UpdaterJavaCode.applySgd(g1, lr); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(g3, lr)); u.applyUpdater(g2, i, 0); assertEquals(g1, g2); + assertEquals(g1, g3); } } From 1f3e4c18e14831d2837425c80515c4b5c5346293 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 23 Mar 2020 07:28:54 +0300 Subject: [PATCH 12/17] some structure for ops (#337) Signed-off-by: raver119 --- .../generic/{parity_ops => images}/adjust_contrast.cpp | 0 .../ops/declarable/generic/{parity_ops => images}/adjust_hue.cpp | 0 .../generic/{parity_ops => images}/adjust_saturation.cpp | 0 .../generic/{parity_ops => images}/crop_and_resize.cpp | 0 .../generic/{parity_ops => images}/draw_bounding_boxes.cpp | 0 .../generic/{parity_ops => images}/extract_image_patches.cpp | 0 .../declarable/generic/{parity_ops => images}/image_resize.cpp | 0 .../declarable/generic/{parity_ops => images}/resize_area.cpp | 0 .../declarable/generic/{parity_ops => images}/resize_bicubic.cpp | 0 .../declarable/generic/{parity_ops => images}/resize_linear.cpp | 0 .../generic/{parity_ops => images}/resize_neighbor.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/betaInc.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/cholesky.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/cross.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/diag.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/diagPart.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/digamma.cpp | 0 .../ops/declarable/generic/{transforms => linalg}/eye.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/lgamma.cpp | 0 .../ops/declarable/generic/{transforms => linalg}/log1p.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/lstsq.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/lup.cpp | 0 .../declarable/generic/{parity_ops => linalg}/matrixDiagPart.cpp | 0 .../declarable/generic/{parity_ops => linalg}/matrixSetDiag.cpp | 0 .../generic/{parity_ops 
=> linalg}/matrix_band_part.cpp | 0 .../generic/{parity_ops => linalg}/matrix_determinant.cpp | 0 .../declarable/generic/{parity_ops => linalg}/matrix_diag.cpp | 0 .../declarable/generic/{parity_ops => linalg}/matrix_inverse.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/moments.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/polygamma.cpp | 0 .../include/ops/declarable/generic/{parity_ops => linalg}/qr.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/solve.cpp | 0 .../generic/{parity_ops => linalg}/sufficient_statistics.cpp | 0 .../ops/declarable/generic/{transforms => linalg}/trace.cpp | 0 .../ops/declarable/generic/{transforms => linalg}/tri.cpp | 0 .../generic/{parity_ops => linalg}/triangular_solve.cpp | 0 .../ops/declarable/generic/{transforms => linalg}/triu.cpp | 0 .../ops/declarable/generic/{parity_ops => linalg}/zeta.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/crelu.cpp | 0 .../include/ops/declarable/generic/{ => nn}/activations/cube.cpp | 0 .../include/ops/declarable/generic/{ => nn}/activations/elu.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/hardsigmoid.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/hardtanh.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/identity.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/identity_n.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/lrelu.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/prelu.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/rationaltanh.cpp | 0 .../declarable/generic/{ => nn}/activations/rectifiedtanh.cpp | 0 .../include/ops/declarable/generic/{ => nn}/activations/relu.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/relu6.cpp | 0 .../include/ops/declarable/generic/{ => nn}/activations/selu.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/sigmoid.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/softplus.cpp | 0 .../ops/declarable/generic/{ => nn}/activations/softsign.cpp | 0 
.../include/ops/declarable/generic/{ => nn}/activations/tanh.cpp | 0 .../declarable/generic/{ => nn}/activations/thresholdedrelu.cpp | 0 .../ops/declarable/generic/{parity_ops => nn}/bias_add.cpp | 0 .../declarable/generic/{parity_ops => nn}/embedding_lookup.cpp | 0 .../ops/declarable/generic/{transforms => nn}/layer_norm.cpp | 0 .../generic/{ => nn}/recurrent/dynamicBidirectionalRNN.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/dynamicRNN.cpp | 0 .../include/ops/declarable/generic/{ => nn}/recurrent/gru.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/gruCell.cpp | 0 .../include/ops/declarable/generic/{ => nn}/recurrent/lstm.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/lstmBlock.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/lstmBlockCell.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/lstmCell.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/lstmLayer.cpp | 0 .../include/ops/declarable/generic/{ => nn}/recurrent/sru.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/sruCell.cpp | 0 .../generic/{ => nn}/recurrent/staticBidirectionalRNN.cpp | 0 .../ops/declarable/generic/{ => nn}/recurrent/staticRNN.cpp | 0 .../ops/declarable/generic/{parity_ops => nn}/xw_plus_b.cpp | 0 .../ops/declarable/generic/{parity_ops => random}/dropout.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/argmax.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/argmin.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/norm.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/reduceMean.cpp | 0 .../declarable/generic/{parity_ops => reduce}/reduceStDev.cpp | 0 .../declarable/generic/{parity_ops => reduce}/reduceVariance.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/reduce_dot.cpp | 0 .../generic/{parity_ops => reduce}/reduce_logsumexp.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/reduce_max.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/reduce_min.cpp | 0 .../declarable/generic/{parity_ops => 
reduce}/reduce_norm1.cpp | 0 .../declarable/generic/{parity_ops => reduce}/reduce_norm2.cpp | 0 .../generic/{parity_ops => reduce}/reduce_norm_max.cpp | 0 .../declarable/generic/{parity_ops => reduce}/reduce_prod.cpp | 0 .../declarable/generic/{parity_ops => reduce}/reduce_sqnorm.cpp | 0 .../ops/declarable/generic/{parity_ops => reduce}/reduce_sum.cpp | 0 .../ops/declarable/generic/{transforms => shape}/flatten.cpp | 0 .../ops/declarable/generic/{parity_ops => shape}/rank.cpp | 0 .../ops/declarable/generic/{parity_ops => shape}/size.cpp | 1 - .../include/ops/declarable/generic/{shape => tensor}/create.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/fill.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/fill_as.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/lin_space.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/ones_as.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/range.cpp | 0 .../declarable/generic/{parity_ops => tensor}/strided_slice.cpp | 0 .../ops/declarable/generic/{parity_ops => tensor}/zeros_as.cpp | 0 .../generic/{parity_ops => transforms}/batch_to_space.cpp | 0 .../generic/{parity_ops => transforms}/batch_to_space_nd.cpp | 0 .../generic/{parity_ops => transforms}/depth_to_space.cpp | 0 .../generic/{parity_ops => transforms}/dynamic_parititon.cpp | 0 .../generic/{parity_ops => transforms}/dynamic_stitch.cpp | 0 .../generic/{parity_ops => transforms}/parallelStack.cpp | 0 .../generic/{parity_ops => transforms}/scatter_add.cpp | 0 .../generic/{parity_ops => transforms}/scatter_div.cpp | 0 .../generic/{parity_ops => transforms}/scatter_max.cpp | 0 .../generic/{parity_ops => transforms}/scatter_min.cpp | 0 .../generic/{parity_ops => transforms}/scatter_mul.cpp | 0 .../declarable/generic/{parity_ops => transforms}/scatter_nd.cpp | 0 .../generic/{parity_ops => transforms}/scatter_nd_add.cpp | 0 .../generic/{parity_ops => transforms}/scatter_nd_sub.cpp | 0 .../generic/{parity_ops => 
transforms}/scatter_nd_update.cpp | 0 .../generic/{parity_ops => transforms}/scatter_sub.cpp | 0 .../generic/{parity_ops => transforms}/scatter_upd.cpp | 0 .../ops/declarable/generic/{parity_ops => transforms}/slice.cpp | 0 .../generic/{parity_ops => transforms}/space_to_batch.cpp | 0 .../generic/{parity_ops => transforms}/space_to_batch_nd.cpp | 0 .../generic/{parity_ops => transforms}/space_to_depth.cpp | 0 .../ops/declarable/generic/{parity_ops => transforms}/split.cpp | 0 .../declarable/generic/{parity_ops => transforms}/split_v.cpp | 0 .../ops/declarable/generic/{parity_ops => transforms}/stack.cpp | 0 .../ops/declarable/generic/{parity_ops => transforms}/tear.cpp | 0 .../declarable/generic/{parity_ops => transforms}/unstack.cpp | 0 128 files changed, 1 deletion(-) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/adjust_contrast.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/adjust_hue.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/adjust_saturation.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/crop_and_resize.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/draw_bounding_boxes.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/extract_image_patches.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/image_resize.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/resize_area.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/resize_bicubic.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/resize_linear.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => images}/resize_neighbor.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/betaInc.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => 
linalg}/cholesky.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/cross.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/diag.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/diagPart.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/digamma.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => linalg}/eye.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/lgamma.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => linalg}/log1p.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/lstsq.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/lup.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrixDiagPart.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrixSetDiag.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrix_band_part.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrix_determinant.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrix_diag.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/matrix_inverse.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/moments.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/polygamma.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/qr.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/solve.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/sufficient_statistics.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => linalg}/trace.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => linalg}/tri.cpp (100%) 
rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/triangular_solve.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => linalg}/triu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => linalg}/zeta.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/crelu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/cube.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/elu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/hardsigmoid.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/hardtanh.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/identity.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/identity_n.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/lrelu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/prelu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/rationaltanh.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/rectifiedtanh.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/relu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/relu6.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/selu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/sigmoid.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/softplus.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/softsign.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/tanh.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/activations/thresholdedrelu.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => nn}/bias_add.cpp (100%) rename 
libnd4j/include/ops/declarable/generic/{parity_ops => nn}/embedding_lookup.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => nn}/layer_norm.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/dynamicBidirectionalRNN.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/dynamicRNN.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/gru.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/gruCell.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/lstm.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/lstmBlock.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/lstmBlockCell.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/lstmCell.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/lstmLayer.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/sru.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/sruCell.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/staticBidirectionalRNN.cpp (100%) rename libnd4j/include/ops/declarable/generic/{ => nn}/recurrent/staticRNN.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => nn}/xw_plus_b.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => random}/dropout.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/argmax.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/argmin.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/norm.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduceMean.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduceStDev.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => 
reduce}/reduceVariance.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_dot.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_logsumexp.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_max.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_min.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_norm1.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_norm2.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_norm_max.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_prod.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_sqnorm.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => reduce}/reduce_sum.cpp (100%) rename libnd4j/include/ops/declarable/generic/{transforms => shape}/flatten.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => shape}/rank.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => shape}/size.cpp (97%) rename libnd4j/include/ops/declarable/generic/{shape => tensor}/create.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/fill.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/fill_as.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/lin_space.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/ones_as.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/range.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/strided_slice.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => tensor}/zeros_as.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => 
transforms}/batch_to_space.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/batch_to_space_nd.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/depth_to_space.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/dynamic_parititon.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/dynamic_stitch.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/parallelStack.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_add.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_div.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_max.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_min.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_mul.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_nd.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_nd_add.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_nd_sub.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_nd_update.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_sub.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/scatter_upd.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/slice.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/space_to_batch.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/space_to_batch_nd.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/space_to_depth.cpp (100%) 
rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/split.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/split_v.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/stack.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/tear.cpp (100%) rename libnd4j/include/ops/declarable/generic/{parity_ops => transforms}/unstack.cpp (100%) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp rename to libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp rename to libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp rename to libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp b/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/crop_and_resize.cpp rename to libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/draw_bounding_boxes.cpp rename to libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/extract_image_patches.cpp b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/extract_image_patches.cpp rename to libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/image_resize.cpp b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/image_resize.cpp rename to libnd4j/include/ops/declarable/generic/images/image_resize.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp rename to libnd4j/include/ops/declarable/generic/images/resize_area.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp rename to libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp rename to libnd4j/include/ops/declarable/generic/images/resize_linear.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp rename to libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp b/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp rename to libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/cholesky.cpp b/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/cholesky.cpp rename to libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/cross.cpp b/libnd4j/include/ops/declarable/generic/linalg/cross.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/cross.cpp rename to libnd4j/include/ops/declarable/generic/linalg/cross.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/diag.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/diag.cpp rename to libnd4j/include/ops/declarable/generic/linalg/diag.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/diagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/diagPart.cpp rename to libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp rename to libnd4j/include/ops/declarable/generic/linalg/digamma.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/eye.cpp 
b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/eye.cpp rename to libnd4j/include/ops/declarable/generic/linalg/eye.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp rename to libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp b/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/log1p.cpp rename to libnd4j/include/ops/declarable/generic/linalg/log1p.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp b/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/lstsq.cpp rename to libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp rename to libnd4j/include/ops/declarable/generic/linalg/lup.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrixDiagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrixDiagPart.cpp rename to libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp rename to 
libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_band_part.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrix_band_part.cpp rename to libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrix_determinant.cpp rename to libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp rename to libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_inverse.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/matrix_inverse.cpp rename to libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/moments.cpp rename to libnd4j/include/ops/declarable/generic/linalg/moments.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp rename to libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp diff 
--git a/libnd4j/include/ops/declarable/generic/parity_ops/qr.cpp b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/qr.cpp rename to libnd4j/include/ops/declarable/generic/linalg/qr.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/solve.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp rename to libnd4j/include/ops/declarable/generic/linalg/solve.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp b/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/sufficient_statistics.cpp rename to libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/trace.cpp b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/trace.cpp rename to libnd4j/include/ops/declarable/generic/linalg/trace.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/tri.cpp b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/tri.cpp rename to libnd4j/include/ops/declarable/generic/linalg/tri.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/triangular_solve.cpp rename to libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/triu.cpp b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/transforms/triu.cpp rename to libnd4j/include/ops/declarable/generic/linalg/triu.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zeta.cpp b/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/zeta.cpp rename to libnd4j/include/ops/declarable/generic/linalg/zeta.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/crelu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/cube.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/cube.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/elu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/hardsigmoid.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/hardtanh.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp diff --git 
a/libnd4j/include/ops/declarable/generic/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/identity.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/identity_n.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/lrelu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/prelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/prelu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/rationaltanh.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/rectifiedtanh.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp diff --git 
a/libnd4j/include/ops/declarable/generic/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/relu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/relu6.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/relu6.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/selu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/selu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/sigmoid.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/softplus.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/softplus.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/softsign.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/softsign.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/tanh.cpp 
b/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/tanh.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp diff --git a/libnd4j/include/ops/declarable/generic/activations/thresholdedrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/activations/thresholdedrelu.cpp rename to libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp rename to libnd4j/include/ops/declarable/generic/nn/bias_add.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp rename to libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp rename to libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/dynamicRNN.cpp 
b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/dynamicRNN.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/gru.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstm.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/lstm.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/recurrent/lstmCell.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/sru.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/sruCell.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/staticBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/staticBidirectionalRNN.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp diff --git a/libnd4j/include/ops/declarable/generic/recurrent/staticRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/recurrent/staticRNN.cpp rename to libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/xw_plus_b.cpp rename to 
libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp b/libnd4j/include/ops/declarable/generic/random/dropout.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/dropout.cpp rename to libnd4j/include/ops/declarable/generic/random/dropout.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp rename to libnd4j/include/ops/declarable/generic/reduce/argmax.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp rename to libnd4j/include/ops/declarable/generic/reduce/argmin.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/norm.cpp rename to libnd4j/include/ops/declarable/generic/reduce/norm.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduceMean.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduceStDev.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp 
b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduceVariance.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_dot.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_dot.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_logsumexp.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_min.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm1.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp similarity index 100% 
rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm2.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_norm_max.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_prod.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_sqnorm.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/reduce_sum.cpp rename to libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp diff --git a/libnd4j/include/ops/declarable/generic/transforms/flatten.cpp b/libnd4j/include/ops/declarable/generic/shape/flatten.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/transforms/flatten.cpp rename to libnd4j/include/ops/declarable/generic/shape/flatten.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp b/libnd4j/include/ops/declarable/generic/shape/rank.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp rename to 
libnd4j/include/ops/declarable/generic/shape/rank.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp b/libnd4j/include/ops/declarable/generic/shape/size.cpp similarity index 97% rename from libnd4j/include/ops/declarable/generic/parity_ops/size.cpp rename to libnd4j/include/ops/declarable/generic/shape/size.cpp index d31e782c6..fd76548cb 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size.cpp @@ -32,7 +32,6 @@ namespace sd { REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); output->p(0, input->lengthOf()); - output->syncToDevice(); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/shape/create.cpp b/libnd4j/include/ops/declarable/generic/tensor/create.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/shape/create.cpp rename to libnd4j/include/ops/declarable/generic/tensor/create.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp rename to libnd4j/include/ops/declarable/generic/tensor/fill.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fill_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/fill_as.cpp rename to libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp rename to libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp 
similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp rename to libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp b/libnd4j/include/ops/declarable/generic/tensor/range.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/range.cpp rename to libnd4j/include/ops/declarable/generic/tensor/range.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp rename to libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp rename to libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space.cpp rename to libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/batch_to_space_nd.cpp rename to libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/depth_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/depth_to_space.cpp rename to libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp rename to libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp rename to libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/parallelStack.cpp b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/parallelStack.cpp rename to libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp rename to libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp rename to libnd4j/include/ops/declarable/generic/transforms/slice.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch.cpp rename to libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/space_to_batch_nd.cpp rename to libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/space_to_depth.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp similarity index 100% rename from 
libnd4j/include/ops/declarable/generic/parity_ops/space_to_depth.cpp rename to libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp b/libnd4j/include/ops/declarable/generic/transforms/split.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/split.cpp rename to libnd4j/include/ops/declarable/generic/transforms/split.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/split_v.cpp b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/split_v.cpp rename to libnd4j/include/ops/declarable/generic/transforms/split_v.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp rename to libnd4j/include/ops/declarable/generic/transforms/stack.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/tear.cpp rename to libnd4j/include/ops/declarable/generic/transforms/tear.cpp diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp b/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp similarity index 100% rename from libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp rename to libnd4j/include/ops/declarable/generic/transforms/unstack.cpp From 55ec207eb8407531c97271b9f2c168ca12c6bf44 Mon Sep 17 00:00:00 2001 From: Oleh Date: Mon, 23 Mar 2020 06:30:26 +0200 Subject: [PATCH 13/17] Split convolutions implementations for compilation speed up (#339) * libnd4j first step of convolutions implementation split Signed-off-by: Oleg * libnd4j convolutions cuda implementation split 
Signed-off-by: Oleg * libnd4j code clean up Signed-off-by: Oleg --- .../declarable/helpers/cpu/convolutions.cpp | 1864 ----------------- .../helpers/cpu/convolutions_col2vol.cpp | 143 ++ .../helpers/cpu/convolutions_conv2d.cpp | 107 + .../helpers/cpu/convolutions_conv2dBP.cpp | 127 ++ .../cpu/convolutions_depthwiseConv2d.cpp | 101 + .../cpu/convolutions_depthwiseConv2dBP.cpp | 120 ++ .../helpers/cpu/convolutions_pooling2d.cpp | 223 ++ .../helpers/cpu/convolutions_pooling2dBP.cpp | 306 +++ .../helpers/cpu/convolutions_pooling3d.cpp | 261 +++ .../helpers/cpu/convolutions_pooling3dBP.cpp | 326 +++ .../helpers/cpu/convolutions_sconv2d.cpp | 73 + .../helpers/cpu/convolutions_upsampling2d.cpp | 80 + .../cpu/convolutions_upsampling2dBP.cpp | 86 + .../helpers/cpu/convolutions_upsampling3d.cpp | 89 + .../cpu/convolutions_upsampling3dBP.cpp | 95 + .../helpers/cpu/convolutions_vol2col.cpp | 147 ++ .../declarable/helpers/cuda/convolutions.cu | 1670 --------------- .../helpers/cuda/convolutions_col2vol.cu | 131 ++ .../helpers/cuda/convolutions_conv2d.cu | 105 + .../helpers/cuda/convolutions_conv2dBP.cu | 125 ++ .../cuda/convolutions_depthwiseConv2d.cu | 101 + .../cuda/convolutions_depthwiseConv2dBP.cu | 120 ++ .../helpers/cuda/convolutions_pooling2d.cu | 342 +++ .../helpers/cuda/convolutions_pooling2dBP.cu | 188 ++ .../helpers/cuda/convolutions_pooling3d.cu | 181 ++ .../helpers/cuda/convolutions_pooling3dBP.cu | 202 ++ .../helpers/cuda/convolutions_sconv2d.cu | 73 + .../helpers/cuda/convolutions_upsampling2d.cu | 97 + .../cuda/convolutions_upsampling2dBP.cu | 103 + .../helpers/cuda/convolutions_upsampling3d.cu | 98 + .../cuda/convolutions_upsampling3dBP.cu | 107 + .../helpers/cuda/convolutions_vol2col.cu | 111 + 32 files changed, 4368 insertions(+), 3534 deletions(-) delete mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp create mode 100644 
libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp delete mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu create mode 100644 
libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp deleted file mode 100644 index 4140c2143..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ /dev/null @@ -1,1864 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 -// - -#include -#include -#include -#include -#include -#include -#include - -namespace sd { - namespace ops { - - -////////////////////////////////////////////////////////////////////////// -// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - template - static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* colBuff = columns.bufferAsT(); - T* volBuff = const_cast(volume).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && 
shape::strideDescendingCAscendingF(columns.getShapeInfo())) { - - auto func = PRAGMA_THREADS_FOR_3D { - T *col, *vol; - int volDep, volRow, volCol; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); - - } else { - - auto func = PRAGMA_THREADS_FOR_2D { - T *col, *vol; - int volDep, volRow, volCol; - for (int b = start_x; b < stop_x; b++) { - for (int colD = start_y; colD < stop_y; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) 
>= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.f); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); - //func(0, 0, bS, 1, 0, oD, 1); - } - } - -////////////////////////////////////////////////////////////////////////// -// [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - template - static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - // initial zeroing of volume content - volume.nullify(); - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* volBuff = volume.bufferAsT(); - T* colBuff = 
const_cast(columns).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int c = 0; c < iC; c++) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = -pD + kDep * dD + colD * sD; - volRow = -pH + kRow * dH + colH * sH; - volCol = -pW + kCol * dW + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, bS); - - } else { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int colD = 0; colD < oD; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * 
colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, bS); - } - } - - -////////////////////////////////////////////////////////////////////////// - template - static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 
== wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - - } - -////////////////////////////////////////////////////////////////////////// - template - static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW 
filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* 
gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -////////////////////////////////////////////////////////////////////////// - template - static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - 
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - } - 
-////////////////////////////////////////////////////////////////////////// - template - static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // 
[bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW - - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - 
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -////////////////////////////////////////////////////////////////////////// - template - static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? 
std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); - - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } - } - -////////////////////////////////////////////////////////////////////////// - template - static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 
1 : 3; - - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oH = output.sizeAt(dimIH); - const uint oW = output.sizeAt(dimIH + 1); - - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimIH]; - const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; - - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimIH]; - const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3; - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < oW; ++w) { - xCoord2 = h / factorH; - xCoord3 = w / factorW; - - z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3]; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); - } - -////////////////////////////////////////////////////////////////////////// - template - static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 
1 : 4; - - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oD = output.sizeAt(dimID); - const uint oH = output.sizeAt(dimID + 1); - const uint oW = output.sizeAt(dimID + 2); - - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimID]; - const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimID]; - const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3, xCoord4; - - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < oH; ++h) { - for (uint w = 0; w < oW; ++w) { - - xCoord2 = d / factorD; - xCoord3 = h / factorH; - xCoord4 = w / factorW; - - z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4] = x[ - b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3 + - xCoord4 * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - -////////////////////////////////////////////////////////////////////////// - template - static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 
1 : 3; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iH = gradI.sizeAt(dimIH); - const uint iW = gradI.sizeAt(dimIH + 1); - - const uint factorH = gradO.sizeAt(dimIH) / iH; - const uint factorW = gradO.sizeAt(dimIH + 1) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; - - z[zOffset] = 0; - - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + xw * xStride3]; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); - } - -////////////////////////////////////////////////////////////////////////// - template - static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 
1 : 4; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iD = gradI.sizeAt(dimID); - const uint iH = gradI.sizeAt(dimID + 1); - const uint iW = gradI.sizeAt(dimID + 2); - - const uint factorD = gradO.sizeAt(dimID) / iD; - const uint factorH = gradO.sizeAt(dimID + 1) / iH; - const uint factorW = gradO.sizeAt(dimID + 2) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < iH; ++h) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4; - - z[zOffset] = 0; - - for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xd * xStride2 + xh * xStride3 + xw * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); - } - -////////////////////////////////////////////////////////////////////////// - template - static void pooling2d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, 
const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iH = input.sizeAt(2); - const int iW = input.sizeAt(3); - const int oC = output.sizeAt(1); - const int oH = output.sizeAt(2); - const int oW = output.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const int kProd = kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / 
dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T max = -DataTypeUtils::max(); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T val = pIn[kh + kw]; - if (val > max) - max = val; - } - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += pIn[kh + kw]; - - if (extraParam0 == 0) { //Exclude padding - int a = (hend - hstart) / iStep2 + ((hend - hstart) % iStep2 == 0 ? 
0 : 1); - int r = (wend - wstart) / iStep3 + ((wend - wstart) % iStep3 == 0 ? 0 : 1); - sum /= static_cast(a * r); // Accounts for dilation - } else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, static_cast((T) 1.f) / extraParam0); - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - 
nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } - } - -////////////////////////////////////////////////////////////////////////// - template - static void pooling3d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iD, iH, iW] - // output is [bS, iC, oD, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iD = input.sizeAt(2); - const int iH = input.sizeAt(3); - const int iW = input.sizeAt(4); - const int oC = output.sizeAt(1); - const int oD = output.sizeAt(2); - const int oH = output.sizeAt(3); - const int oW = output.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - const Nd4jLong oStride4 = output.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const int kProd = kD*kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int 
b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = -DataTypeUtils::max(); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T val = pIn[kd + kh + kw]; - if (val > sum) - sum = val; - } - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart 
+ kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += pIn[kd + kh + kw]; - - if (extraParam0 == 0) //Exclude padding - sum /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(iStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(iStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(iStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 
0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect poooling3d mode"); - } - } - - -////////////////////////////////////////////////////////////////////////// - template - static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iH, iW] - // gradI [bS, iC, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iH = gradI.sizeAt(2); - 
const int iW = gradI.sizeAt(3); - const int oC = gradO.sizeAt(1); - const int oH = gradO.sizeAt(2); - const int oW = gradO.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const Nd4jLong gIStep2 = dH*gIStride2; - const Nd4jLong gIStep3 = dW*gIStride3; - const int kProd = kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if 
(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T valIn = pIn[kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kh * iStride2 + kw * iStride3]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + maxKW * gIStride3] += valO; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend 
> iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= gIStride2; - hend *= gIStride2; - wstart *= gIStride3; - wend *= gIStride3; - - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if ((int) extraParam0 == 0) //Exclude padding - valO /= static_cast(sd::math::nd4j_ceil( - static_cast(hend - hstart) / static_cast(gIStep2))) * - static_cast(sd::math::nd4j_ceil( - static_cast(wend - wstart) / - static_cast(gIStep3))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) - pgI[kh + kw] += valO; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = sameStrides ? 
gI + (pIn - in) : gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = static_cast(0.f); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. - extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO * sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f) * - sd::math::nd4j_sgn(pIn[kh + kw]); - } else { - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh * iStride2 + kw * iStride3]), - extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. 
- extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) { - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kh * iStride2 + kw * iStride3]; - pgI[kh * gIStride2 + kw * gIStride3] += valO * - sd::math::nd4j_pow( - sd::math::nd4j_abs( - inVal), - extraParam0 - 1.f) * - sd::math::nd4j_sgn( - inVal); - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect pooling2dBP mode"); - } - } - -////////////////////////////////////////////////////////////////////////// - template - static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iD, iH, iW] - // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oD, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iD = gradI.sizeAt(2); - const int iH = gradI.sizeAt(3); - const int iW = gradI.sizeAt(4); - const int oC = gradO.sizeAt(1); - const int oD = gradO.sizeAt(2); - const int oH = gradO.sizeAt(3); - const int oW = gradO.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong 
iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong gIStride4 = gradI.stridesOf()[4]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong oStride4 = gradO.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const Nd4jLong gIStep2 = dD*gIStride2; - const Nd4jLong gIStep3 = dH*gIStride3; - const Nd4jLong gIStep4 = dW*gIStride4; - const int kProd = kD*kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - 
sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - maxKD = dstart; - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T valIn = pIn[kd + kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKD + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - maxKD = dstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + maxKH * gIStride3 + maxKW * gIStride4] += valO; - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - 
hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= gIStride2; - dend *= gIStride2; - hstart *= gIStride3; - hend *= gIStride3; - wstart *= gIStride4; - wend *= gIStride4; - - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (extraParam0 == 0) //Exclude padding - valO /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(gIStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(gIStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(gIStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) - pgI[kd + kh + kw] += valO; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = gI + (pIn - in); - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * 
((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - sum = static_cast(0.); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]),extraParam0 - (T) 1.f) * sd::math::nd4j_sgn(pIn[kd + kh + kw]); - } else { - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; - pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * sd::math::nd4j_sgn(inVal); - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - 
nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } - } - - - - - void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int 
dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - } - void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - } - void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); - } - void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - 
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); - } - void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); - } - void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); - } - - - - void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - void ConvolutionUtils::pooling3dBP(sd::graph::Context& 
block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp new file mode 100644 index 000000000..c9cae504a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// +// [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] +template +static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + // initial zeroing of volume content + volume.nullify(); + + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* volBuff = volume.bufferAsT(); + T* colBuff = const_cast(columns).bufferAsT(); + + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && 
shape::strideDescendingCAscendingF(columns.getShapeInfo())) { + + auto func = PRAGMA_THREADS_FOR { + T* col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int c = 0; c < iC; c++) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + + volDep = -pD + kDep * dD + colD * sD; + volRow = -pH + kRow * dH + colH * sH; + volCol = -pW + kCol * dW + colW * sW; + + if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; + vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; + *vol += *col; + } + } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, bS); + + } else { + + auto func = PRAGMA_THREADS_FOR { + T* col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int colD = 0; colD < oD; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; + vol = volBuff + b * 
volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; + *vol += *col; + } + } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, bS); + } + } + +void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp new file mode 100644 index 000000000..45e66651c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp @@ -0,0 +1,107 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); + + std::vector permutForOutput; + + if(isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if(0 == wFormat) + wAxes = {0, 1, 2}; + else if(1 == 
wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if(isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if(bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if(!isNCHW) + delete input; + + } + +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp new file mode 100644 index 000000000..6a01a4a4d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp @@ -0,0 +1,127 @@ +/******************************************************************************* + * Copyright (c) 
2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, 
input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); + + std::vector gradOaxesForDot; + + if(!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + + if(0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + else if(1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } + else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if(gradW) { + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW + if(gradBR != gradB) + delete gradBR; + } + + //----- calculation of 
gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] + // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); + + helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if(!isNCHW) { + delete input; + delete gradI; + } + } + +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp new file mode 100644 index 000000000..fa86dbd6c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp @@ -0,0 +1,101 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> 
modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if(!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } + else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + + if(paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); + + helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + + if(bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if(!isNCHW) + delete input; + } + +void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const 
int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp new file mode 100644 index 000000000..7c0d933e2 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp @@ -0,0 +1,120 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + 
std::vector gradOreShape; + + if(!isNCHW) { + gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } + else { + gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + + if(paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, 
{0,indOoH,indOoH+1}); // sum over bS, oH, oW + + if(gradBR != gradB) + delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if(!isNCHW) { + delete input; + delete gradI; + } + } + +void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); + } + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp new file mode 100644 index 000000000..26dc4f99e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp @@ -0,0 +1,223 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// + template + static void pooling2d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iH = input.sizeAt(2); + const int iW = input.sizeAt(3); + const int oC = output.sizeAt(1); + const int oH = output.sizeAt(2); + const int oW = output.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + + const Nd4jLong iStep2 = dH*iStride2; + const Nd4jLong iStep3 = dW*iStride3; + const int kProd = kH*kW; + + if(poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { 
+ for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T max = -DataTypeUtils::max(); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T val = pIn[kh + kw]; + if (val > max) + max = val; + } + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH 
* ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += pIn[kh + kw]; + + if (extraParam0 == 0) { //Exclude padding + int a = (hend - hstart) / iStep2 + ((hend - hstart) % iStep2 == 0 ? 0 : 1); + int r = (wend - wstart) / iStep3 + ((wend - wstart) % iStep3 == 0 ? 0 : 1); + sum /= static_cast(a * r); // Accounts for dilation + } else if (extraParam0 == 1) //Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); 
//(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + sum = sd::math::nd4j_pow(sum, static_cast((T) 1.f) / extraParam0); + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + else { + nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } + } + + void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp new file mode 100644 index 000000000..03f34bfae --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp @@ -0,0 +1,306 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// + template + static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input [bS, iC, iH, iW] + // gradI [bS, iC, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iH = gradI.sizeAt(2); + const int iW = gradI.sizeAt(3); + const int oC = gradO.sizeAt(1); + const int oH = gradO.sizeAt(2); + const int oW = gradO.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong 
gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong iStep2 = dH*iStride2; + const Nd4jLong iStep3 = dW*iStride3; + const Nd4jLong gIStep2 = dH*gIStride2; + const Nd4jLong gIStep3 = dW*gIStride3; + const int kProd = kH*kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3; + + if(poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; + + if (sameStrides) { + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T valIn = pIn[kh + kw]; + if (valIn 
> sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; + } + } + gI[pIn - in + maxKH + maxKW] += valO; + } else { + + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = pIn[kh * iStride2 + kw * iStride3]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; + } + } + + gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + maxKW * gIStride3] += valO; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pgI = gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= gIStride2; + hend *= gIStride2; + wstart *= gIStride3; + wend *= gIStride3; + + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; + + if ((int) extraParam0 == 0) //Exclude padding + valO /= static_cast(sd::math::nd4j_ceil( + static_cast(hend - hstart) / 
static_cast(gIStep2))) * + static_cast(sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast(gIStep3))); //Accounts for dilation + else if ((int) extraParam0 == 1) //Include padding + valO /= kProd; + + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) + pgI[kh + kw] += valO; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + pgI = sameStrides ? gI + (pIn - in) : gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + sum = static_cast(0.f); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; + + if (sameStrides) { + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), 
extraParam0); + + valO *= sd::math::nd4j_pow(sum, + ((T) 1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + pgI[kh + kw] += valO * sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f) * + sd::math::nd4j_sgn(pIn[kh + kw]); + } else { + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh * iStride2 + kw * iStride3]), + extraParam0); + + valO *= sd::math::nd4j_pow(sum, + ((T) 1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) { + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = pIn[kh * iStride2 + kw * iStride3]; + pgI[kh * gIStride2 + kw * gIStride3] += valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs( + inVal), + extraParam0 - 1.f) * + sd::math::nd4j_sgn( + inVal); + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + else { + nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw std::runtime_error("Incorrect pooling2dBP mode"); + } + } + +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp new file mode 100644 index 000000000..04d5f993a --- /dev/null +++ 
b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp @@ -0,0 +1,261 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// + template + static void pooling3d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input is [bS, iC, iD, iH, iW] + // output is [bS, iC, oD, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kDEff = kD + (kD-1)*(dD-1); + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iD = input.sizeAt(2); + const int iH = input.sizeAt(3); + const int iW = input.sizeAt(4); + const int oC = output.sizeAt(1); + const int oD = output.sizeAt(2); + const int oH = output.sizeAt(3); + const int oW = output.sizeAt(4); + + 
nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + const Nd4jLong oStride4 = output.stridesOf()[4]; + const Nd4jLong iStep2 = dD*iStride2; + const Nd4jLong iStep3 = dH*iStride3; + const Nd4jLong iStep4 = dW*iStride4; + const int kProd = kD*kH*kW; + + if(poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = -DataTypeUtils::max(); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T val = pIn[kd + kh + kw]; 
+ if (val > sum) + sum = val; + } + + out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } +/*************************************************************************/ + else if(poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += pIn[kd + kh + kw]; + + if (extraParam0 == 0) //Exclude padding + sum /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(iStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(iStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(iStep4)); //Accounts for dilation + else if (extraParam0 == 1) //Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + od * oStride2 + 
oh * oStride3 + ow * oStride4] = sum; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); + + sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); + + out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + else { + nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw std::runtime_error("Incorrect 
poooling3d mode"); + } + } + +void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp new file mode 100644 index 000000000..02f6f57ac --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp @@ -0,0 +1,326 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// + template + static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input [bS, iC, iD, iH, iW] + // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oD, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kDEff = kD + (kD-1)*(dD-1); + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iD = gradI.sizeAt(2); + const int iH = gradI.sizeAt(3); + const int iW = gradI.sizeAt(4); + const int oC = gradO.sizeAt(1); + const int oD = gradO.sizeAt(2); + const int oH = gradO.sizeAt(3); + const int oW = gradO.sizeAt(4); + + nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong gIStride4 = gradI.stridesOf()[4]; + const Nd4jLong 
oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong oStride4 = gradO.stridesOf()[4]; + const Nd4jLong iStep2 = dD*iStride2; + const Nd4jLong iStep3 = dH*iStride3; + const Nd4jLong iStep4 = dW*iStride4; + const Nd4jLong gIStep2 = dD*gIStride2; + const Nd4jLong gIStep3 = dH*gIStride3; + const Nd4jLong gIStep4 = dW*gIStride4; + const int kProd = kD*kH*kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4; + + if(poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + maxKD = dstart; + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = 
hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T valIn = pIn[kd + kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; + } + } + gI[pIn - in + maxKD + maxKH + maxKW] += valO; + } else { + + // we set these to default values + maxKH = hstart; + maxKW = wstart; + maxKD = dstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; + } + } + + gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + maxKH * gIStride3 + maxKW * gIStride4] += valO; + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pgI = gI + b * gIStride0 + c * gIStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= gIStride2; + dend *= gIStride2; + hstart *= gIStride3; + hend *= gIStride3; + wstart *= gIStride4; + 
wend *= gIStride4; + + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; + + if (extraParam0 == 0) //Exclude padding + valO /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(gIStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(gIStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(gIStep4)); //Accounts for dilation + else if (extraParam0 == 1) //Include padding + valO /= kProd; + + for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) + pgI[kd + kh + kw] += valO; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + pgI = gI + (pIn - in); + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = static_cast(0.); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + + dstart *= iStride2; + dend *= iStride2; + hstart *= 
iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); + + valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + pgI[kd + kh + kw] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]),extraParam0 - (T) 1.f) * sd::math::nd4j_sgn(pIn[kd + kh + kw]); + } else { + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]), extraParam0); + + valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; + pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * sd::math::nd4j_sgn(inVal); + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + else { + nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } + } + + void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, 
const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp new file mode 100644 index 000000000..742f88c3b --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + NDArray* outputDepth = output; + if(weightsPoint) // if pointwise convolution is expected + outputDepth = new NDArray(output->ordering(), !isNCHW ? 
std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + + // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); + + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } + } + +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp new file mode 100644 index 000000000..ffdd5c34b --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { + // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 
1 : 3; + + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oH = output.sizeAt(dimIH); + const uint oW = output.sizeAt(dimIH + 1); + + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimIH]; + const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; + + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimIH]; + const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3; + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < oW; ++w) { + xCoord2 = h / factorH; + xCoord3 = w / factorW; + + z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3]; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); + } + + +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp new file mode 100644 index 000000000..aba46aabc --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 
1 : 3; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iH = gradI.sizeAt(dimIH); + const uint iW = gradI.sizeAt(dimIH + 1); + + const uint factorH = gradO.sizeAt(dimIH) / iH; + const uint factorW = gradO.sizeAt(dimIH + 1) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < iW; ++w) { + + const auto zOffset = b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; + + z[zOffset] = 0; + + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + xw * xStride3]; + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); + } + +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp new file mode 100644 index 000000000..7b86ec5a1 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp @@ -0,0 +1,89 @@ 
+/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 
1 : 4; + + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oD = output.sizeAt(dimID); + const uint oH = output.sizeAt(dimID + 1); + const uint oW = output.sizeAt(dimID + 2); + + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimID]; + const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimID]; + const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3, xCoord4; + + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < oH; ++h) { + for (uint w = 0; w < oW; ++w) { + + xCoord2 = d / factorD; + xCoord3 = h / factorH; + xCoord4 = w / factorW; + + z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4] = x[ + b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3 + + xCoord4 * xStride4]; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + + void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); + } + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp new file mode 100644 index 000000000..93c2746fb --- 
/dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { + + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 
1 : 4; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iD = gradI.sizeAt(dimID); + const uint iH = gradI.sizeAt(dimID + 1); + const uint iW = gradI.sizeAt(dimID + 2); + + const uint factorD = gradO.sizeAt(dimID) / iD; + const uint factorH = gradO.sizeAt(dimID + 1) / iH; + const uint factorW = gradO.sizeAt(dimID + 2) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < iH; ++h) { + for (uint w = 0; w < iW; ++w) { + + const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4; + + z[zOffset] = 0; + + for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += x[b * xStride0 + c * xStride1 + xd * xStride2 + xh * xStride3 + xw * xStride4]; + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); + } + + + void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); + } + +} +} diff --git 
a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp new file mode 100644 index 000000000..552dceb6a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp @@ -0,0 +1,147 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 +// + +#include +#include + +namespace sd { + namespace ops { + + +////////////////////////////////////////////////////////////////////////// +// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] +template +static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = 
columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* colBuff = columns.bufferAsT(); + T* volBuff = const_cast(volume).bufferAsT(); + + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) { + + auto func = PRAGMA_THREADS_FOR_3D { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) + *col = static_cast(0.); + else { + vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; + *col = *vol; 
+ } + } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); + + } else { + + auto func = PRAGMA_THREADS_FOR_2D { + T *col, *vol; + int volDep, volRow, volCol; + for (int b = start_x; b < stop_x; b++) { + for (int colD = start_y; colD < stop_y; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) + *col = static_cast(0.f); + else { + vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; + *col = *vol; + } + } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); + //func(0, 0, bS, 1, 0, oD, 1); + } + } + +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); +} + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu deleted file mode 100644 index 47da861ed..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ /dev/null @@ -1,1670 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com) -// - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sd { -namespace ops { - -////////////////////////////////////////////////////////////////////////// -// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, oW] -template -static __global__ void vol2colCuda(const void* volume, const Nd4jLong* volShapeInfo, void* columns, const Nd4jLong* colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* vol = reinterpret_cast(volume); - T* col = reinterpret_cast(columns); - - __shared__ int colRank, volRank; - __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - volRank = 5; - colRank = 8; - - colLen = shape::length(colShapeInfo); - - iD = volShapeInfo[3]; - iH = volShapeInfo[4]; - iW = volShapeInfo[5]; - } - __syncthreads(); - - const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(colInd >= colLen) - return; - - auto coords = 
sharedMem + threadIdx.x * colRank; - - shape::index2coords(colInd, colShapeInfo, coords); - - // const auto colW = coords[7]; - // const auto colH = coords[6]; - // const auto colD = coords[5]; - // const auto kCol = coords[4]; - // const auto kRow = coords[3]; - // const auto kDep = coords[2]; - // const auto c = coords[1]; - // const auto b = coords[0]; - - const auto colOffset = shape::getOffset(colShapeInfo, coords); - - coords[2] = -pD + coords[2] * dD + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; - coords[3] = -pH + coords[3] * dH + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; - coords[4] = -pW + coords[4] * dW + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(coords[2]) >= static_cast(iD) || static_cast(coords[3]) >= static_cast(iH) || static_cast(coords[4]) >= static_cast(iW)) - col[colOffset] = static_cast(0.); - else - col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; -} - -////////////////////////////////////////////////////////////////////////// -template -static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* volume, const Nd4jLong* volShapeInfo, - void* columns, const Nd4jLong* colShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "vol2col"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = 
(col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&col}, {&vol}); - BUILD_SINGLE_SELECTOR(vol.dataType(), vol2colCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), vol.getSpecialBuffer(), vol.getSpecialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&col}, {&vol}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, iC, iD, iH, iW] -template -static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShapeInfo, void* volume, const Nd4jLong* volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* col = reinterpret_cast(columns); - T* vol = reinterpret_cast(volume); - - __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; - __shared__ Nd4jLong volLen; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - oD = colShapeInfo[6]; - oH = colShapeInfo[7]; - oW = colShapeInfo[8]; - - kD = dD * (colShapeInfo[3] - 1) + 1; - kH = dH * (colShapeInfo[4] - 1) + 1; - kW = dW * (colShapeInfo[5] - 1) + 1; - - volLen = shape::length(volShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * 8; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { - - shape::index2coords(i, volShapeInfo, coords); - - const auto volOffset = shape::getOffset(volShapeInfo, coords); - - const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; - - const uint imD = coords[2] + pD; - const uint imH = coords[3] + pH; - const 
uint imW = coords[4] + pW; - - const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; - const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; - const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - - const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); - const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); - const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); - - T val = 0; - - for(uint colD = colDstart; colD < colDend; ++colD) { - coords[2] = imD - colD * sD; - if(coords[2] % dD != 0) continue; - - for(uint colH = colHstart; colH < colHend; ++colH) { - coords[3] = imH - colH * sH; - if(coords[3] % dH != 0) continue; - - for(uint colW = colWstart; colW < colWend; ++colW) { - coords[4] = imW - colW * sW; - if(coords[4] % dW != 0) continue; - - val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]]; - - } - } - } - - vol[volOffset] = val; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* columns, const Nd4jLong* colShapeInfo, - void* volume, const Nd4jLong* volShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "col2vol"); - - const int threadsPerBlock = 
MAX_NUM_THREADS / 4; - const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; - - NDArray::prepareSpecialUse({&vol}, {&col}); - BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&vol}, {&col}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // 
[bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 == wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); -} - -////////////////////////////////////////////////////////////////////////// -template -static void 
depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - 
else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); -} - -////////////////////////////////////////////////////////////////////////// -template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or 
[bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); - - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? 
nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); -} - -////////////////////////////////////////////////////////////////////////// -template -static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); - - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = 
shape::stride(xShapeInfo)[3]; - - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; - - length = shape::length(zShapeInfo); - - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); - - Z sum = 0.0f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += static_cast(inSlice[h * strideY + w * strideX]); - - int divide_factor = pool_size; //Case 0: exclude padding - if (extraParam0 == 1) //Case 1: include padding - divide_factor = kH * kW; - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sum / static_cast(divide_factor); - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void avgPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void 
*vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -template -static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); - - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; - - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; - - length = shape::length(zShapeInfo); - - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - 
const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if (hstart < 0) { - int f = sd::math::nd4j_ceil((Z) -hstart / (Z) dH); - hstart += f * dH; - } - if (wstart < 0) { - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if (hend > iH) { - int f = sd::math::nd4j_ceil((Z) (hend - iH) / (Z) dH); - hend -= f * dH; - } - if (wend > iW) { - int f = sd::math::nd4j_ceil((Z) (wend - iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend - hstart) / (double) dH) * - sd::math::nd4j_ceil((double) (wend - wstart) / (double) dW); - - Z sum = 0.f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += sd::math::nd4j_pow(static_cast(sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), extraParam0); - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sd::math::nd4j_pow(sum, (Z) 1.0f / extraParam0); - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void pnormPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -template -static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - 
// output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); - - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; - - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; - - length = shape::length(zShapeInfo); - - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); 
- - Z max = -sd::DataTypeUtils::max(); - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) { - for (int w = wstart; w < wend; w += dW) { - Z v = static_cast(inSlice[h * strideY + w * strideX]); - if (v > max) - max = v; - } - } - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void maxPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - - switch (poolingMode) { - - case MAX_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case AVG_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case PNORM_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, 
(*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - default: - throw std::runtime_error("Pooling2D: Unknown PoolingType used"); - } - - output.tickWriteDevice(); - input.tickReadDevice(); - - auto result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); - if (result != 0) - throw cuda_exception::build("Pooling2D failed", result); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x input is [bS, iC, iD, iH, iW] - // z output is [bS, iC, oD, oH, oW] - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - 
int wend = wstart + kWeff; - - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) - max = val; - } - } - } - z[zOffset] = max; - } - break; - - /*** avg ***/ - case 1: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += x[shape::getOffset(xShapeInfo, coords)]; - - if (extraParam0 == 0) { //Exclude padding - uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); - uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1); - uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 
0 : 1); - sum /= static_cast(a * b * c); // /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - } - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - z[zOffset] = sum; - } - break; - - /*** pnorm ***/ - case 2: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - z[zOffset] = sum; - } - break; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - PointersManager manager(block.launchContext(), "pooling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS 
/ 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iH, iW] - // y: gradO [bS, iC, oH, oW] - // z: gradI [bS, iC, iH, iW] -> gradI is output in this function - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3; - __shared__ int rank, kHeff, kWeff, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 4; - - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iH = xShapeInfo[3]; - iW = xShapeInfo[4]; - - kProd = kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = 
shape::getOffset(yShapeInfo, coords); - - int hstart = coords[2] * sH - pH; - int wstart = coords[3] * sW - pW; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - coord2 = hstart; - coord3 = wstart; - - T max = -DataTypeUtils::max(); - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW){ - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - } - } - } - coords[2] = coord2; - coords[3] = coord3; - auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); - //z[zOffset] += y[yOffset]; - } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); - } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = hstart; coords[2] < hend; 
coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } - } - break; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling2dBP"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), 
input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iD, iH, iW] - // y: gradO [bS, iC, oD, oH, oW] - // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3, coord4; - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - 
- if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - coord4 = coords[4]; - } - } - } - } - coords[2] = coord2; - coords[3] = coord3; - coords[4] = coord4; - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); - } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); - } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += 
sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } - } - } - break; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling3dBP"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - 
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - 
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, 
kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); -} - -////////////////////////////////////////////////////////////////////////// -template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides 
height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, 
kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); -} - - 
-////////////////////////////////////////////////////////////////////////// -template -__global__ static void upsampling2dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW) { - - // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, dimIH; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - coords[dimIH] /= factorH; - coords[dimIH + 1] /= factorW; - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; -} - -////////////////////////////////////////////////////////////////////////// -template -static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorH, const int factorW, const bool isNCHW) { - - upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d"); - - const int 
threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void upsampling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, dimID; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - dimID = isNCDHW ? 
2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - coords[dimID] /= factorD; - coords[dimID + 1] /= factorH; - coords[dimID + 2] /= factorW; - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; -} - -////////////////////////////////////////////////////////////////////////// -template -static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - 
manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { - - // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // z (gradI) has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, dimIH; - __shared__ uint factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; - - factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; - factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - z[zOffset] = 0; - - const Nd4jLong zCoord2 = coords[dimIH] * factorH; - const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; - - for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) - for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; -} - -////////////////////////////////////////////////////////////////////////// -template -static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); -} - 
-////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); -} - -////////////////////////////////////////////////////////////////////////// -template -__global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { - - // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, dimID; - __shared__ uint factorD, factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - dimID = isNCDHW ? 
2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; - - factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; - factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; - factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - z[zOffset] = 0; - - const Nd4jLong zCoord2 = coords[dimID] * factorD; - const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; - const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; - - for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) - for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) - for(coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; ++coords[dimID + 2]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; -} - -////////////////////////////////////////////////////////////////////////// -template -static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCDHW) { - - upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); -} - -////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - 
BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); -} - - - - - - - - - -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu new file mode 100644 index 000000000..d751c2b1e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu @@ -0,0 +1,131 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, iC, iD, iH, iW] +template +static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShapeInfo, void* volume, const Nd4jLong* volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + const T* col = reinterpret_cast(columns); + T* vol = reinterpret_cast(volume); + + __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; + __shared__ Nd4jLong volLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + oD = colShapeInfo[6]; + oH = colShapeInfo[7]; + oW = colShapeInfo[8]; + + kD = dD * (colShapeInfo[3] - 1) + 1; + kH = dH * (colShapeInfo[4] - 1) + 1; + kW = dW * (colShapeInfo[5] - 1) + 1; + + volLen = shape::length(volShapeInfo); + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * 8; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { + + shape::index2coords(i, volShapeInfo, coords); + + const auto volOffset = shape::getOffset(volShapeInfo, coords); + + const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; + + const uint imD = coords[2] + pD; + const uint imH = coords[3] + pH; + const uint imW = coords[4] + pW; + + const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 
0 : (imW - kW) / sW + 1; + + const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); + const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); + + T val = 0; + + for(uint colD = colDstart; colD < colDend; ++colD) { + coords[2] = imD - colD * sD; + if(coords[2] % dD != 0) continue; + + for(uint colH = colHstart; colH < colHend; ++colH) { + coords[3] = imH - colH * sH; + if(coords[3] % dH != 0) continue; + + for(uint colW = colWstart; colW < colWend; ++colW) { + coords[4] = imW - colW * sW; + if(coords[4] % dW != 0) continue; + + val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]]; + + } + } + } + + vol[volOffset] = val; + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* columns, const Nd4jLong* colShapeInfo, + void* volume, const Nd4jLong* volShapeInfo, + const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + PointersManager manager(block.launchContext(), "col2vol"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; + + 
NDArray::prepareSpecialUse({&vol}, {&col}); + BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + NDArray::registerSpecialUse({&vol}, {&col}); + + manager.synchronize(); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu new file mode 100644 index 000000000..494ce4a81 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu @@ -0,0 +1,105 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector permutForOutput; + + if(isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if(0 == wFormat) + wAxes = {0, 1, 2}; + else if(1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, 
iC}, input->dataType(), input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if(isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if(bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if(!isNCHW) + delete input; + +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu new file mode 100644 index 000000000..dbf4ee390 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu @@ -0,0 +1,125 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. 
+ * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input 
height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector gradOaxesForDot; + + if(!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + if(0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + else if(1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } + else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if(gradW) { + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW + if(gradBR != gradB) + delete gradBR; + } + + //----- calculation of gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, 
oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] + // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + + helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if(!isNCHW) { + delete input; + delete gradI; + } +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu new file mode 100644 index 000000000..bbf5d5892 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu @@ -0,0 +1,101 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = {{1,0,4,5,2,3}, 
{iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if(!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } + else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + + if(paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); + + helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + + if(bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if(!isNCHW) + delete input; +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int 
dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu new file mode 100644 index 000000000..b06af6166 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu @@ -0,0 +1,120 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + std::vector gradOreShape; + + 
if(!isNCHW) { + gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } + else { + gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if(0 == wFormat) + modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; + else if(1 == wFormat) + modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; + else + modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; + + if(paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, 
oH, oW + if(gradBR != gradB) + delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if(!isNCHW) { + delete input; + delete gradI; + } +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu new file mode 100644 index 000000000..eb336cb76 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu @@ -0,0 +1,342 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH-1)*(dH-1); + kWEff = kW + (kW-1)*(dW-1); + } + 
__syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if(hstart < 0){ + int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); + hstart += f * dH; + } + if(wstart < 0){ + int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); + wstart += f * dW; + } + if(hend > iH){ + int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); + hend -= f * dH; + } + if(wend > iW){ + int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); + wend -= f * dW; + } + + //Accounts for dilation + int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); + + Z sum = 0.0f; + + const X *inSlice = x + (n * strideB + c * strideC); + + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += static_cast(inSlice[h * strideY + w * strideX]); + + int divide_factor = pool_size; //Case 0: exclude padding + if (extraParam0 == 1) //Case 1: include padding + divide_factor = kH * kW; + + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sum / static_cast(divide_factor); + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void avgPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// 
+template +static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH-1)*(dH-1); + kWEff = kW + (kW-1)*(dW-1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z) -hstart / (Z) dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); + wstart += f * dW; + } + if (hend > iH) { + int f = 
sd::math::nd4j_ceil((Z) (hend - iH) / (Z) dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z) (wend - iW) / (Z) dW); + wend -= f * dW; + } + //Accounts for dilation + int pool_size = sd::math::nd4j_ceil((double) (hend - hstart) / (double) dH) * + sd::math::nd4j_ceil((double) (wend - wstart) / (double) dW); + + Z sum = 0.f; + + const X *inSlice = x + (n * strideB + c * strideC); + + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += sd::math::nd4j_pow(static_cast(sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), extraParam0); + + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sd::math::nd4j_pow(sum, (Z) 1.0f / extraParam0); + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void pnormPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = 
shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH-1)*(dH-1); + kWEff = kW + (kW-1)*(dW-1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if(hstart < 0){ + int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); + hstart += f * dH; + } + if(wstart < 0){ + int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); + wstart += f * dW; + } + if(hend > iH){ + int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); + hend -= f * dH; + } + if(wend > iW){ + int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); + wend -= f * dW; + } + //Accounts for dilation + int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); + + Z max = -sd::DataTypeUtils::max(); + + const X *inSlice = x + (n * strideB + c * strideC); + + for (int h = hstart; h < hend; h += dH) { + for (int w = wstart; w < wend; w += dW) { + Z v = static_cast(inSlice[h * strideY + w * strideX]); + if (v > max) + max = v; + } + } + + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; + } +} + 
+////////////////////////////////////////////////////////////////////////// +template +static void maxPooling2dCudaLauncher(sd::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { + maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { + + if(!input.isActualOnDeviceSide()) input.syncToDevice(); + + switch (poolingMode) { + + case MAX_POOL: { + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); + } + break; + case AVG_POOL: { + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); + } + break; + case PNORM_POOL: { + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); + } + break; + default: + throw std::runtime_error("Pooling2D: Unknown PoolingType used"); + } + + output.tickWriteDevice(); + input.tickReadDevice(); + + auto 
result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); + if (result != 0) + throw cuda_exception::build("Pooling2D failed", result); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu new file mode 100644 index 000000000..26808ad4c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu @@ -0,0 +1,188 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + // x: input [bS, iC, iH, iW] + // y: gradO [bS, iC, oH, oW] + // z: gradI [bS, iC, iH, iW] -> gradI is output in this function + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3; + __shared__ int rank, kHeff, kWeff, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 4; + + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iH = xShapeInfo[3]; + iW = xShapeInfo[4]; + + kProd = kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(yInd >= yLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int hstart = coords[2] * sH - pH; + int wstart = coords[3] * sW - pW; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + + switch 
(poolingMode) { + + /*** max ***/ + case 0: { + coord2 = hstart; + coord3 = wstart; + + T max = -DataTypeUtils::max(); + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW){ + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + } + } + } + coords[2] = coord2; + coords[3] = coord3; + auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); + //z[zOffset] += y[yOffset]; + } + break; + + /*** avg ***/ + case 1: { + + T val = y[yOffset]; + + if (extraParam0 == 0) //Exclude padding + val /= sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation + else if (extraParam0 == 1) //Include padding + val /= kProd; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); + } + break; + + /*** pnorm ***/ + case 2: { + + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); + + val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); + } + } + } + break; + } +} + 
+////////////////////////////////////////////////////////////////////////// +template +static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const int poolingMode, const int extraParam0) { + + pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling2dBP"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); +} + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu 
b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu new file mode 100644 index 000000000..93e372a7e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu @@ -0,0 +1,181 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + // x input is [bS, iC, iD, iH, iW] + // z output is [bS, iC, oD, oH, oW] + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = 
5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(zInd >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + + switch (poolingMode) { + + /*** max ***/ + case 0: { + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) + max = val; + } + } + } + z[zOffset] = max; + } + break; + + /*** avg ***/ + case 1: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += x[shape::getOffset(xShapeInfo, coords)]; + + if (extraParam0 == 0) { //Exclude padding + uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); + uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 
0 : 1); + uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1); + sum /= static_cast(a * b * c); // /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation + } + else if (extraParam0 == 1) //Include padding + sum /= kProd; + + z[zOffset] = sum; + } + break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); + + sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); + + z[zOffset] = sum; + } + break; + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + + pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + PointersManager 
manager(block.launchContext(), "pooling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu new file mode 100644 index 000000000..51b48bc23 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu @@ -0,0 +1,202 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + // x: input [bS, iC, iD, iH, iW] + // y: gradO [bS, iC, oD, oH, oW] + // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3, coord4; + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(yInd >= yLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if(dstart < 0) + dstart += dD * ((-dstart + 
dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend - iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend - iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend - iW + dW - 1) / dW); + + + switch (poolingMode) { + + /*** max ***/ + case 0: { + + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + coord4 = coords[4]; + } + } + } + } + coords[2] = coord2; + coords[3] = coord3; + coords[4] = coord4; + sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); + } + break; + + /*** avg ***/ + case 1: { + + T val = y[yOffset]; + + if (extraParam0 == 0) //Exclude padding + val /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation + else if (extraParam0 == 1) //Include padding + val /= kProd; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); + } + break; + + /*** pnorm ***/ + case 2: { + + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); + + 
val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); + } + } + } + } + break; + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + + pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling3dBP"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int 
sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu new file mode 100644 index 000000000..3a9ed5364 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + NDArray* outputDepth = output; + if(weightsPoint) // if pointwise convolution is expected + outputDepth = new NDArray(output->ordering(), !isNCHW ? 
std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + + // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); + + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu new file mode 100644 index 000000000..be9fab0be --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu @@ -0,0 +1,97 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void upsampling2dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW) { + + // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, dimIH; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + dimIH = isNCHW ? 
2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(zInd >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + coords[dimIH] /= factorH; + coords[dimIH + 1] /= factorW; + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + + z[zOffset] = x[xOffset]; +} + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int factorH, const int factorW, const bool isNCHW) { + + upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { + + PointersManager manager(block.launchContext(), "upsampling2d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); +} + +} +} diff --git 
a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu new file mode 100644 index 000000000..ce393d279 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu @@ -0,0 +1,103 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { + + // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + // z (gradI) has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, dimIH; + __shared__ uint factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + dimIH = isNCHW ? 
2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; + + factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; + factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(zInd >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + z[zOffset] = 0; + + const Nd4jLong zCoord2 = coords[dimIH] * factorH; + const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; + + for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) + for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; +} + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const bool isNCHW) { + + upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + + PointersManager manager(block.launchContext(), "upsampling2d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), 
gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); +} + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu new file mode 100644 index 000000000..6f15a27d6 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void upsampling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + + // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, dimID; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + dimID = isNCDHW ? 
2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(zInd >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + coords[dimID] /= factorD; + coords[dimID + 1] /= factorH; + coords[dimID + 2] /= factorW; + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + + z[zOffset] = x[xOffset]; +} + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + + upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + + PointersManager manager(block.launchContext(), "upsampling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + 
manager.synchronize(); +} + +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu new file mode 100644 index 000000000..f9eb56bec --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu @@ -0,0 +1,107 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { + + // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, dimID; + __shared__ uint factorD, factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + dimID = isNCDHW ? 
2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; + + factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; + factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; + factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(zInd >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + z[zOffset] = 0; + + const Nd4jLong zCoord2 = coords[dimID] * factorD; + const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; + const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; + + for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) + for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) + for(coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; ++coords[dimID + 2]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; +} + +////////////////////////////////////////////////////////////////////////// +template +static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const bool isNCDHW) { + + upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { + + PointersManager manager(block.launchContext(), "upsampling3d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + 
BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); +} + + +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu new file mode 100644 index 000000000..ebe0ec26e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace sd { +namespace ops { + +////////////////////////////////////////////////////////////////////////// +// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, oW] +template +static __global__ void vol2colCuda(const void* volume, const Nd4jLong* volShapeInfo, void* columns, const Nd4jLong* colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + const T* vol = reinterpret_cast(volume); + T* col = reinterpret_cast(columns); + + __shared__ int colRank, volRank; + __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + volRank = 5; + colRank = 8; + + colLen = shape::length(colShapeInfo); + + iD = volShapeInfo[3]; + iH = volShapeInfo[4]; + iW = volShapeInfo[5]; + } + __syncthreads(); + + const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(colInd >= colLen) + return; + + auto coords = sharedMem + threadIdx.x * colRank; + + shape::index2coords(colInd, colShapeInfo, coords); + + // const auto colW = coords[7]; + // const auto colH = coords[6]; + // const auto colD = coords[5]; + // const auto kCol = coords[4]; + // const auto kRow = coords[3]; + // const auto kDep = coords[2]; + // const auto c = coords[1]; + // const auto b = coords[0]; + + const auto colOffset = shape::getOffset(colShapeInfo, coords); + + coords[2] = -pD + coords[2] * dD + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; + coords[3] = -pH + coords[3] * dH + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; + coords[4] = -pW + coords[4] * dW + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; + + if 
(static_cast(coords[2]) >= static_cast(iD) || static_cast(coords[3]) >= static_cast(iH) || static_cast(coords[4]) >= static_cast(iW)) + col[colOffset] = static_cast(0.); + else + col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; +} + +////////////////////////////////////////////////////////////////////////// +template +static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* volume, const Nd4jLong* volShapeInfo, + void* columns, const Nd4jLong* colShapeInfo, + const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +} + +////////////////////////////////////////////////////////////////////////// +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + + PointersManager manager(block.launchContext(), "vol2col"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&col}, {&vol}); + BUILD_SINGLE_SELECTOR(vol.dataType(), vol2colCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), vol.getSpecialBuffer(), vol.getSpecialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + NDArray::registerSpecialUse({&col}, {&vol}); + + manager.synchronize(); +} + +} +} \ No newline at end of file From 226f0672bc0e5d0fa7d0397483d91b623b224ccf Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 23 Mar 2020 17:02:03 +0300 Subject: [PATCH 14/17] size op fixed 
Signed-off-by: raver119 --- libnd4j/include/ops/declarable/generic/shape/size.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/libnd4j/include/ops/declarable/generic/shape/size.cpp b/libnd4j/include/ops/declarable/generic/shape/size.cpp index fd76548cb..d31e782c6 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size.cpp @@ -32,6 +32,7 @@ namespace sd { REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); output->p(0, input->lengthOf()); + output->syncToDevice(); return Status::OK(); } From 838c3ddb5a30891044d671c03c1eabba7b2f2f0c Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 24 Mar 2020 12:05:17 +1100 Subject: [PATCH 15/17] Timeouts and temp ignore for logged issue - #8802 (#342) Signed-off-by: Alex Black --- .../java/org/deeplearning4j/AssertTestsExtendBaseClass.java | 5 +++++ .../src/test/java/org/nd4j/AssertTestsExtendBaseClass.java | 5 +++++ .../org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java | 2 ++ 3 files changed, 12 insertions(+) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java index 20d2967bb..5f0567094 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -37,6 +37,11 @@ import static org.junit.Assert.assertEquals; @Slf4j public class AssertTestsExtendBaseClass extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) private static final Set> exclusions = new HashSet<>(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java index 5d8a70725..658a63579 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java @@ -42,6 +42,11 @@ import static org.junit.Assert.assertEquals; @Slf4j public class AssertTestsExtendBaseClass extends BaseND4JTest { + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) private static final Set> exclusions = new HashSet<>(Arrays.asList( TFGraphTestAllSameDiff.class, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java index 7addd5098..004aac209 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java @@ -1,6 +1,7 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -82,6 +83,7 @@ public class SameDiffMultiThreadTests extends BaseND4JTest { } @Test + @Ignore //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802 public void testMobilenet() throws Exception { TFGraphTestZooModels.currentTestDir = testDir.newFolder(); File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt"); From 3cbba495183f3af1a7d437f6f7ac6f5b0c77ed0b Mon Sep 17 00:00:00 2001 From: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Date: Tue, 24 Mar 2020 03:55:47 +0200 Subject: [PATCH 16/17] Bugfix failing builds (#341) * Fix interpreter for libnd4j 
tests and drop test script * Remove mingw when specifying javacpp.platform, add new profile that triggers when javacpp.platform is windows-x86_64 * Update android 32 bit toolchain for x86 * Try triple instead of -target * Change to -target * Update 32 bit arm * Change android bin path * Update arm 32 bit build again Co-authored-by: Adam Gibson <1144306+agibsonccc@users.noreply.github.com> --- libnd4j/buildnativeoperations.sh | 5 ++-- libnd4j/cmake/android-arm.cmake | 4 +-- libnd4j/cmake/android-x86.cmake | 2 +- libnd4j/pom.xml | 2 +- nd4j/compile-android.sh | 1 - .../nd4j-backend-impls/nd4j-native/pom.xml | 27 ++++++++++++++++++- 6 files changed, 33 insertions(+), 8 deletions(-) delete mode 100644 nd4j/compile-android.sh diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index af9154866..6a1f93dbb 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -220,8 +220,9 @@ case "$OS" in setandroid_defaults - - export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" + # Note here for android 32 bit prefix on the binutils is different + # See https://developer.android.com/ndk/guides/other_build_systems + export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi/prebuilt/$KERNEL/" export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" export ANDROID_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/bin/clang" export ANDROID_ROOT="$ANDROID_NDK/platforms/android-$ANDROID_VERSION/arch-arm/" diff --git a/libnd4j/cmake/android-arm.cmake b/libnd4j/cmake/android-arm.cmake index 9d150b070..427bc6a34 100644 --- a/libnd4j/cmake/android-arm.cmake +++ b/libnd4j/cmake/android-arm.cmake @@ -1,7 +1,7 @@ # CMake toolchain to build for Android 5.0 or newer. 
Sample usage: # set(CMAKE_SYSTEM_NAME Android) -set(CMAKE_ANDROID_ARCH_ABI arm64-v8a) +set(CMAKE_ANDROID_ARCH_ABI armeabi-v7a) set(CMAKE_ANDROID_NDK "$ENV{ANDROID_NDK}") set(CMAKE_ANDROID_STL_TYPE c++_shared) set(CMAKE_SYSTEM_VERSION "$ENV{ANDROID_VERSION}") @@ -18,5 +18,5 @@ endif (WIN32) -add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target aarch64-none-linux-android -march=armv8-a) +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target armv7a-linux-androideabi) diff --git a/libnd4j/cmake/android-x86.cmake b/libnd4j/cmake/android-x86.cmake index 7c3297b74..7290b0b8d 100644 --- a/libnd4j/cmake/android-x86.cmake +++ b/libnd4j/cmake/android-x86.cmake @@ -18,5 +18,5 @@ endif (WIN32) -add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target x86-none-linux-android) +add_definitions(-D__ANDROID_API__=$ENV{ANDROID_VERSION} -DANDROID -fPIC -ffunction-sections -funwind-tables -fstack-protector-strong -target i686-linux-android) diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 8db086245..d682da24c 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -184,7 +184,7 @@ ${libnd4j.test.skip} ${basedir}/tests_cpu - sh + bash run_tests.sh --chip ${libnd4j.chip} diff --git a/nd4j/compile-android.sh b/nd4j/compile-android.sh deleted file mode 100644 index da7fa1799..000000000 --- a/nd4j/compile-android.sh +++ /dev/null @@ -1 +0,0 @@ -mvn clean install -Djavacpp.platform=android-arm64 -Dmaven.test.skip=true -Djavacpp.platform.compiler=$ANDROID_NDK/toolchains/llvm/prebuilt/windows-x86_64/bin/clang++ diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 466bc6264..4026de17e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml 
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -293,6 +293,32 @@ windows + + !javacpp.platform + + + + + + org.bytedeco + javacpp + + ${javacpp.platform}-mingw + + + + +
+ + mingw-windows-platform + + + windows + + + javacpp.platform + windows-x86_64 + @@ -306,7 +332,6 @@ - libnd4j-assembly From 4e8f3a025faea346bda2e43d18767100c3e59753 Mon Sep 17 00:00:00 2001 From: Shams Ul Azeem Date: Tue, 24 Mar 2020 13:11:57 +0500 Subject: [PATCH 17/17] Fixing python object for obtaining scalars (#330) * Fixing python object for obtaining scalars Signed-off-by: shams * Fix variable name for stridePtr Signed-off-by: shams * Fix variable name for stridePtr Signed-off-by: shams Co-authored-by: Alex Black --- .../java/org/datavec/python/PythonObject.java | 26 +++++----- .../datavec/python/ScalarAndArrayTest.java | 48 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java index 0408e3a59..4a6a617d5 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -77,7 +77,7 @@ public class PythonObject { long address = bp.address(); long size = bp.capacity(); - NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.BYTE).build(); + NumpyArray npArr = NumpyArray.builder().address(address).shape(new long[]{size}).strides(new long[]{1}).dtype(DataType.INT8).build(); nativePythonObject = Python.memoryview(new PythonObject(npArr)).nativePythonObject; } @@ -320,20 +320,23 @@ public class PythonObject { public NumpyArray toNumpy() throws PythonException{ PyObject np = PyImport_ImportModule("numpy"); PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); - if (PyObject_IsInstance(nativePythonObject, ndarray) == 0){ + if (PyObject_IsInstance(nativePythonObject, ndarray) != 1){ throw new PythonException("Object is not 
a numpy array! Use Python.ndarray() to convert object to a numpy array."); } Py_DecRef(ndarray); Py_DecRef(np); + Pointer objPtr = new Pointer(nativePythonObject); PyArrayObject npArr = new PyArrayObject(objPtr); Pointer ptr = PyArray_DATA(npArr); - SizeTPointer shapePtr = PyArray_SHAPE(npArr); long[] shape = new long[PyArray_NDIM(npArr)]; - shapePtr.get(shape, 0, shape.length); - SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + SizeTPointer shapePtr = PyArray_SHAPE(npArr); + if (shapePtr != null) + shapePtr.get(shape, 0, shape.length); long[] strides = new long[shape.length]; - stridesPtr.get(strides, 0, strides.length); + SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + if (stridesPtr != null) + stridesPtr.get(strides, 0, strides.length); int npdtype = PyArray_TYPE(npArr); DataType dtype; @@ -345,28 +348,27 @@ public class PythonObject { case NPY_SHORT: dtype = DataType.SHORT; break; case NPY_INT: - dtype = DataType.INT; break; + dtype = DataType.INT32; break; case NPY_LONG: dtype = DataType.LONG; break; case NPY_UINT: dtype = DataType.UINT32; break; case NPY_BYTE: - dtype = DataType.BYTE; break; + dtype = DataType.INT8; break; case NPY_UBYTE: - dtype = DataType.UBYTE; break; + dtype = DataType.UINT8; break; case NPY_BOOL: dtype = DataType.BOOL; break; case NPY_HALF: - dtype = DataType.HALF; break; + dtype = DataType.FLOAT16; break; case NPY_LONGLONG: dtype = DataType.INT64; break; case NPY_USHORT: dtype = DataType.UINT16; break; case NPY_ULONG: - dtype = DataType.UINT64; break; case NPY_ULONGLONG: dtype = DataType.UINT64; break; - default: + default: throw new PythonException("Unsupported array data type: " + npdtype); } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java b/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java new file mode 100644 index 000000000..e6b1bf606 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java @@ -0,0 +1,48 @@ 
+/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertEquals; + +@RunWith(Parameterized.class) +public class ScalarAndArrayTest { + + @Parameterized.Parameters(name = "{index}: Testing with INDArray={0}") + public static INDArray[] data() { + return new INDArray[]{ + Nd4j.scalar(10), + Nd4j.ones(10, 10, 10, 10) + }; + } + + private INDArray indArray; + + public ScalarAndArrayTest(INDArray indArray) { + this.indArray = indArray; + } + + @Test + public void testINDArray() throws PythonException { + assertEquals(indArray, new PythonObject(indArray).toNumpy().getNd4jArray()); + } +}