diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index f5b3ab4ed..d0af6c8ed 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -223,6 +223,12 @@ namespace nd4j { return nd4j_sgn(val); } + template + math_def inline Z nd4j_gamma(X a); + + template + math_def inline Z nd4j_lgamma(X x); + //#ifndef __CUDACC__ /* template<> @@ -656,9 +662,56 @@ namespace nd4j { return p_pow(static_cast(val), static_cast(val2)); } + /** + * LogGamma(x) = ln(Gamma(x)) - floating point extension of ln((n-1)!); arguments x <= 0 are not checked (the guard below is disabled) + **/ + template + math_def inline Z nd4j_lgamma(X x) { +// if (x <= X(0.0)) +// { +// std::stringstream os; +// os << "Logarithm of Gamma has sense only for positive values, but " << x << " was given."; +// throw std::invalid_argument( os.str() ); +// } + + if (x < X(12.0)) { + return nd4j_log(nd4j_gamma(x)); + } + + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittaker and Watson + // A Course in Modern Analysis (1927), page 252 + + static const double c[8] = { + 1.0/12.0, + -1.0/360.0, + 1.0/1260.0, + -1.0/1680.0, + 1.0/1188.0, + -691.0/360360.0, + 1.0/156.0, + -3617.0/122400.0 + }; + + double z = Z(1.0 / Z(x * x)); + double sum = c[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += c[i]; + } + + double series = sum / Z(x); + + static const double halfLogTwoPi = 0.91893853320467274178032973640562; + + return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + halfLogTwoPi + series); + } - template + + template math_def inline T nd4j_re(T val1, T val2) { if (val1 == (T) 0.0f && val2 == (T) 0.0f) return (T) 0.0f; @@ -735,7 +788,105 @@ namespace nd4j { template math_def inline Z nd4j_gamma(X a) { - return (Z)std::tgamma(a); +// nd4j_lgamma(a); +// return (Z)std::tgamma(a); + // Split the function domain into three intervals: + // (0, 0.001), [0.001, 12), and [12, infinity) + + /////////////////////////////////////////////////////////////////////////// + // 
First interval: (0, 0.001) + // + // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... + // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3. + // The relative error over this interval is less than 6e-7. + + const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant + + if (a < X(0.001)) + return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); + + /////////////////////////////////////////////////////////////////////////// + // Second interval: [0.001, 12) + + if (a < X(12.0)) { + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + + double y = (double)a; + int n = 0; + bool argWasLessThanOne = y < 1.0; + + // Add or subtract integers as necessary to bring y into (1,2) + // Will correct for this below + if (argWasLessThanOne) { + y += 1.0; + } + else { + n = static_cast(floor(y)) - 1; // will use n later + y -= n; + } + + // numerator coefficients for approximation over the interval (1,2) + static const double p[] = { + -1.71618513886549492533811E+0, + 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, + 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, + -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, + 6.64561438202405440627855E+4 + }; + + // denominator coefficients for approximation over the interval (1,2) + static const double q[] = { + -3.08402300119738975254353E+1, + 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, + -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, + 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, + -1.15132259675553483497211E+5 + }; + + double num = 0.0; + double den = 1.0; + + + double z = y - 1; + for (auto i = 0; i < 8; i++) { + num = (num + p[i]) * z; + den = den * z + q[i]; + } + double result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (argWasLessThanOne) { + 
// Use identity gamma(z) = gamma(z+1)/z + // The variable "result" now holds gamma of the original y + 1 + // Thus we use y-1 to get back the original y. + result /= (y - 1.0); + } + else { + // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) + for (auto i = 0; i < n; i++) + result *= y++; + } + + return Z(result); + } + + /////////////////////////////////////////////////////////////////////////// + // Third interval: [12, infinity) + + if (a > 171.624) { + // Correct answer is too large for a double. Return DOUBLE_MAX_VALUE cast to Z (not +infinity). + return Z(DOUBLE_MAX_VALUE); + //DataTypeUtils::infOrMax(); + } + + return nd4j::math::nd4j_exp(nd4j::math::nd4j_lgamma(a)); } template