From 871f3bb3e6b2ac71ebd45d5e283436a1118ce477 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Tue, 5 Nov 2019 17:16:17 +0200 Subject: [PATCH] - add additional condition in svd helper to take into account rounding errors (#31) Signed-off-by: Yurii --- libnd4j/include/helpers/cpu/svd.cpp | 2 +- libnd4j/include/ops/declarable/helpers/cpu/svd.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 382a3a125..38d3b9ff4 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -469,7 +469,7 @@ void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra useBisection = true; if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) useBisection = true; - if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev)) + if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) useBisection = true; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 03fa44131..35615287b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -466,7 +466,7 @@ void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra useBisection = true; if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) useBisection = true; - if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev)) + if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) useBisection = true; }