From 7f0c660d8b7ad69bcf4f5f733ea130ac1e1bc619 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 27 Aug 2019 15:05:43 +0300 Subject: [PATCH] [WIP] HGemm (#181) * skip string arrays for device validation Signed-off-by: raver119 * confusion_matrix fix Signed-off-by: raver119 * exclude cublasHGemm from archs < 530 Signed-off-by: raver119 --- libnd4j/include/helpers/cuda_off/MmulHelper.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index dda709545..19e0d5baf 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -228,6 +228,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou float alphaF(alpha), betaF(beta); status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); } +#if __CUDA_ARCH__ >= 530 else if(ABC && aType == DataType::HALF) { float16 alphaH(alpha), betaH(beta); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); @@ -240,6 +241,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou float alphaF(alpha), betaF(beta); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); } +#endif else { dim3 threadsPerBlock(N, M); dim3 blocksPerGrid(1, 1);