There is a much better way to implement java.lang.Math.fma(double, double,
double) than the current default implementation in
src/java.base/share/classes/java/lang/Math.java.
The implementation of java.lang.Math.fma below has been shown through
benchmarking to be much faster than (new BigDecimal(a)).multiply(new
BigDecimal(b)).add(new BigDecimal(c)).doubleValue() when a, b, and c are all
finite floating-point values:
/**
 * Computes (hi + lo) * 2**exp, rounded once to the nearest double.
 *
 * <p>The straightforward Math.scalb(hi + lo, exp) can round twice when the
 * scaled result lands at or below Double.MIN_NORMAL: once when forming
 * hi + lo, and again inside Math.scalb. In that case this method recovers
 * the errors of both steps and adds them back through a round-to-odd
 * correction term, so that only one effective rounding takes place.
 *
 * @param hi the high-order part of the value to scale; should be finite
 * @param lo the low-order part of the value to scale; should be finite,
 *        with |hi| >= |lo| whenever hi != 0.0
 * @param exp the power of two by which hi + lo is scaled
 * @return the scaled sum, rounded once to the nearest double
 */
private static double scalbFiniteF64Sum(double hi, double lo, int exp) {
// Both hi and lo should be finite, and |hi| >= |lo| should be true if
// hi != 0.0
final double sum0 = hi + lo;
final double scaleResult0 = Math.scalb(sum0, exp);
if (!(Math.abs(scaleResult0) <= Double.MIN_NORMAL) || lo == 0.0) {
// If |scaleResult0| > Double.MIN_NORMAL or lo == 0.0, then
// scaleResult0 will be correctly rounded.
// Return scaleResult0 in this case.
return scaleResult0;
}
// From here on the scaled result is tiny (|scaleResult0| <=
// Double.MIN_NORMAL), so Math.scalb may have rounded a second time.
// scaleResult1 is equal to scaleResult0 with its ULP bit cleared.
// This ensures that the final result will be correctly rounded if
// the ULP bit of scaleResult0 is set.
final double scaleResult1 = Double.longBitsToDouble(
Double.doubleToRawLongBits(scaleResult0) & (-2L));
// scaleErr is equal to the error of scaling sum0 by 2**exp, scaled back
// up by 2**-exp. scaleResult1 is used to compute scaleErr to ensure
// that the value of the ULP bit of scaleResult0, scaled up by 2**-exp,
// is included in scaleErr to ensure that the final result is correctly
// rounded.
final double scaleErr = sum0 - Math.scalb(scaleResult1, -exp);
// err0 is equal to the error of hi + lo (Fast2Sum error term, which
// relies on |hi| >= |lo|)
final double err0 = (hi - sum0) + lo;
// Compute scaleErr + err0 using the Fast2Sum algorithm as
// |scaleErr| >= |err0| should be true if |scaleErr| > 0
final double sum1 = scaleErr + err0;
final double err1 = (scaleErr - sum1) + err0;
final long sum1Bits = Double.doubleToRawLongBits(sum1);
final long err1Bits = Double.doubleToRawLongBits(err1);
// sum1IsInexactInSignBit is negative if and only if err1 is nonzero
// (+0.0 and -0.0 both map to Long.MAX_VALUE), i.e. if sum1 is inexact.
final long sum1IsInexactInSignBit = err1Bits
^ (err1Bits + 0x7FFF_FFFF_FFFF_FFFFL);
// roundedToOddSum1 is equal to the rounded to odd sum of scaleErr and
// err0: when sum1 is inexact, its bit pattern is decremented by one if
// sum1 and err1 have different signs (moving it one step toward zero),
// and its lowest bit is then set.
final double roundedToOddSum1 = Double.longBitsToDouble((sum1Bits
+ (((sum1Bits ^ err1Bits) & sum1IsInexactInSignBit) >> 63))
| (sum1IsInexactInSignBit >>> 63));
// scaleResult2 is equal to the correctly rounded value of
// (hi + lo) * 2**exp; the final Math.scalb of the round-to-odd
// correction is where the single effective rounding happens.
final double scaleResult2 = scaleResult1
+ Math.scalb(roundedToOddSum1, exp);
return scaleResult2;
}
/**
 * Returns the fused multiply-add of the three arguments; that is, the exact
 * value of {@code a * b + c}, rounded once to the nearest {@code double}.
 *
 * <p>Inputs for which the ordinary expressions {@code a * b + c} or
 * {@code a * b} already yield the IEEE 754 result take fast paths. These
 * are: a zero or non-finite {@code a} or {@code b}; a non-finite or zero
 * {@code c}; or a finite, exactly representable product. All other inputs
 * are handled in three steps. First, the product of the significands is
 * split into an exact double-double {@code pHi + pLo}. Next, it is summed
 * with a correspondingly scaled {@code c} using error-free transformations.
 * Finally, the result is rounded once, with the help of round-to-odd
 * intermediate sums.
 *
 * @param a a value
 * @param b a value
 * @param c a value
 * @return (<i>a</i>&nbsp;&times;&nbsp;<i>b</i>&nbsp;+&nbsp;<i>c</i>)
 *         computed as if with unlimited range and precision, and rounded
 *         once to the nearest {@code double} value
 */
public static double fma(double a, double b, double c) {
    final long aBits = Double.doubleToRawLongBits(a);
    final long bBits = Double.doubleToRawLongBits(b);
    final long cBits = Double.doubleToRawLongBits(c);
    final int aBiasedExp = (int) (aBits >>> 52) & 0x7FF;
    final int bBiasedExp = (int) (bBits >>> 52) & 0x7FF;
    final int cBiasedExp = (int) (cBits >>> 52) & 0x7FF;
    // xBits ^ (xBits - 1) is negative only for +0.0 and -0.0, and
    // 0x7FE - xBiasedExp is negative only for infinities and NaNs.
    if (((aBits ^ (aBits - 1)) | (bBits ^ (bBits - 1))
            | (0x7FE - aBiasedExp) | (0x7FE - bBiasedExp)) < 0) {
        // If at least one of a or b is zero or non-finite, the result is
        // equal to a * b + c: a * b is then exact (a signed zero), an
        // infinity, or NaN, and the IEEE 754 sign rules for an exact zero
        // sum match those of the addition.
        return a * b + c;
    } else if (cBiasedExp == 0x7FF) {
        // If a and b are both nonzero finite numbers and c is a non-finite
        // value, simply return c since the exact result of a * b + c will
        // be equal to c in this case.
        return c;
    } else if (c == 0.0) {
        // a and b are nonzero finite numbers, so the exact product is
        // nonzero and the zero addend cannot change it. Return the
        // once-rounded product itself. Computing a * b + c instead would
        // turn a product that underflows to -0.0 into +0.0 when c is
        // +0.0. IEEE 754 (section 6.3) requires the zero result to take
        // the sign of the nonzero exact result.
        return a * b;
    }
    // a, b, and c are all nonzero finite values at this point.

    // Normalize a and b to normal floating-point numbers.
    final int aIsDenormalMask = (aBiasedExp - 1) >> 31;
    final int bIsDenormalMask = (bBiasedExp - 1) >> 31;
    final long aNormalizeAdjBits = (aBits & 0x8000_0000_0000_0000L)
            | (aIsDenormalMask & 0x0350_0000_0000_0000L);
    final long bNormalizeAdjBits = (bBits & 0x8000_0000_0000_0000L)
            | (bIsDenormalMask & 0x0350_0000_0000_0000L);
    // If a is a denormal number, normalizedA is equal to a * 2**52.
    // Otherwise, if a is already a normal number, normalizedA is equal to
    // a. The same holds for b and normalizedB.
    final double normalizedA =
            Double.longBitsToDouble(aBits | aNormalizeAdjBits)
            - Double.longBitsToDouble(aNormalizeAdjBits);
    final double normalizedB =
            Double.longBitsToDouble(bBits | bNormalizeAdjBits)
            - Double.longBitsToDouble(bNormalizeAdjBits);
    final long normalizedABits = Double.doubleToRawLongBits(normalizedA);
    final long normalizedBBits = Double.doubleToRawLongBits(normalizedB);
    final int normalizedABiasedExp = (int) (normalizedABits >>> 52) & 0x7FF;
    final int normalizedBBiasedExp = (int) (normalizedBBits >>> 52) & 0x7FF;
    // minPBiasedExp is the smallest possible biased exponent (with an
    // exponent bias of 1023) of the exact value of a * b, i.e.
    // floor(log2(|a|)) + floor(log2(|b|)) + 1023.
    final int minPBiasedExp = normalizedABiasedExp + normalizedBBiasedExp
            + (aIsDenormalMask & -52) + (bIsDenormalMask & -52) + -1023;
    if (minPBiasedExp >= 2048) {
        // |a * b| >= 2**1025 while |c| < 2**1024, so a * b + c is known to
        // overflow to infinity. Return a * b + c in this case.
        return a * b + c;
    }

    // Normalize c to a normal floating-point number.
    final int cIsDenormalMask = (cBiasedExp - 1) >> 31;
    final long cNormalizeAdjBits = (cBits & 0x8000_0000_0000_0000L)
            | (cIsDenormalMask & 0x0350_0000_0000_0000L);
    // If c is a denormal number, normalizedC is equal to c * 2**52.
    // Otherwise, if c is already a normal number, normalizedC is equal to
    // c.
    final double normalizedC =
            Double.longBitsToDouble(cBits | cNormalizeAdjBits)
            - Double.longBitsToDouble(cNormalizeAdjBits);
    final long normalizedCBits = Double.doubleToRawLongBits(normalizedC);
    final int normalizedCBiasedExp = (int) (normalizedCBits >>> 52) & 0x7FF;
    // adjCBiasedExp is equal to floor(log2(|c|)) + 1023, even if c is a
    // denormal number.
    final int adjCBiasedExp = normalizedCBiasedExp
            + (cIsDenormalMask & -52);
    final int expDiff = adjCBiasedExp - minPBiasedExp;
    if (expDiff >= 56) {
        // Let ec = floor(log2(|c|)). Then |a * b| < 2**(ec - expDiff + 2)
        // <= 2**(ec - 54). That is strictly less than half the spacing
        // between c and either of its neighboring doubles, including the
        // narrower spacing just below a power-of-two c. So |a * b| is too
        // small to move the correctly rounded result away from c.
        //
        // A threshold of 55 would be wrong. For a == 0x1.cp-28,
        // b == -0x1.cp-27 and c == 1.0 (expDiff == 55), the correctly
        // rounded result is Math.nextDown(1.0), not c.
        return c;
    }

    // aMant and bMant are the significands of a and b, keeping the signs
    // of a and b, with 1 <= |aMant| < 2 and 1 <= |bMant| < 2.
    final double aMant = Double.longBitsToDouble(
            (normalizedABits | 0x3FF0_0000_0000_0000L)
            & 0xBFFF_FFFF_FFFF_FFFFL);
    final double bMant = Double.longBitsToDouble(
            (normalizedBBits | 0x3FF0_0000_0000_0000L)
            & 0xBFFF_FFFF_FFFF_FFFFL);
    // Split aMant and bMant using Veltkamp-Dekker splitting
    // (134217729 == 2**27 + 1).
    final double aMantGamma = aMant * 134217729.0;
    final double bMantGamma = bMant * 134217729.0;
    final double aMantHi = aMantGamma + (aMant - aMantGamma);
    final double bMantHi = bMantGamma + (bMant - bMantGamma);
    final double aMantLo = aMant - aMantHi;
    final double bMantLo = bMant - bMantHi;
    // Dekker's product: pHi + pLo is the exact value of aMant * bMant.
    final double pHi = aMant * bMant;
    final double pLo = ((aMantHi * bMantHi - pHi)
            + (aMantHi * bMantLo + aMantLo * bMantHi)) + aMantLo * bMantLo;
    if (minPBiasedExp >= 1 && pLo == 0.0) {
        // pLo == 0.0 means that the exact product has at most 53
        // significant bits. minPBiasedExp >= 1 means that it is not in the
        // subnormal range. So a * b is exact unless it overflows, and then
        // a * b + c performs the single rounding of the exact result.
        final double product = a * b;
        if (Double.isFinite(product)) {
            return product + c;
        }
        // The rounded product overflowed, but the exact sum with c may
        // still be finite, e.g. fma(0x1p1023, 2.0, -Double.MAX_VALUE) is
        // 0x1p971. Fall through to the algorithms below, which apply the
        // 2**resultScaleUpExp scaling only at the very end.
    }
    final int resultScaleUpExp = minPBiasedExp - 1023;
    if (expDiff <= -106) {
        // If expDiff <= -106 is true, then |a * b| >= 2**-968 must be true
        // since 2**-1074 <= |c| < 2**-105 * |a * b|. In particular
        // minPBiasedExp >= 1 here.
        if (pLo == 0.0) {
            // The product is exactly representable, so (as minPBiasedExp
            // >= 1) the fast path above only fell through because a * b
            // overflowed. |c| < 2**-105 * |a * b| is far too small to bring
            // the exact sum back into the finite range. Return a * b
            // (an infinity) in this case.
            return a * b;
        } else {
            // Let cScaled be c * 2**-resultScaleUpExp. Then
            // 0 < |cScaled| < 2**-105. Meanwhile pLo is a nonzero multiple
            // of 2**-104, so cScaled only needs to contribute a sticky
            // bit. Take pLo's bit pattern, move it one step toward zero
            // if pLo and c have different signs, and then set its lowest
            // bit. The resulting value lies strictly on the same side of
            // pLo as pLo + cScaled. It is used as the round-to-odd low
            // part of the final sum.
            final long pLoBits = Double.doubleToRawLongBits(pLo);
            final double roundedToOddLoSum = Double.longBitsToDouble(
                    (pLoBits + ((pLoBits ^ cBits) >> 63)) | 1L);
            // Return the correctly rounded value of
            // (pHi + roundedToOddLoSum) * 2**resultScaleUpExp
            return scalbFiniteF64Sum(pHi, roundedToOddLoSum,
                    resultScaleUpExp);
        }
    }

    // -105 <= expDiff <= 55 is now true at this point.
    // scaledC is c * 2**-resultScaleUpExp, obtained exactly by giving the
    // significand of c the biased exponent 1023 + expDiff.
    final double scaledC = Double.longBitsToDouble(
            ((normalizedCBits | 0x3FF0_0000_0000_0000L)
            & 0xBFFF_FFFF_FFFF_FFFFL) + ((long) expDiff << 52));
    // Compute pLo + scaledC using the 2Sum algorithm
    final double s0 = pLo + scaledC;
    final double v0 = s0 - pLo;
    final double e0 = (pLo - (s0 - v0)) + (scaledC - v0);
    // s0 + e0 == pLo + scaledC
    // Compute pHi + s0 using the 2Sum algorithm
    final double s1 = pHi + s0;
    final double v1 = s1 - pHi;
    final double e1 = (pHi - (s1 - v1)) + (s0 - v1);
    // s1 + e1 + e0 == pHi + pLo + scaledC
    // Compute the rounded to odd sum of e1 and e0 to ensure that the final
    // result is correctly rounded.
    // Fast2Sum is sufficient for computing e1 + e0 as either e1 == 0.0 or
    // |e1| >= |e0| should be true
    final double s2 = e1 + e0;
    final double e2 = (e1 - s2) + e0;
    // s1 + s2 + e2 == pHi + pLo + scaledC
    final long s2Bits = Double.doubleToRawLongBits(s2);
    final long e2Bits = Double.doubleToRawLongBits(e2);
    // s2IsInexactInSignBit is negative if and only if e2 is nonzero,
    // i.e. if s2 is inexact.
    final long s2IsInexactInSignBit = e2Bits
            ^ (e2Bits + 0x7FFF_FFFF_FFFF_FFFFL);
    // roundedToOddS2 is equal to the rounded to odd sum of e1 and e0.
    // When s2 is inexact, its bit pattern is moved one step toward zero
    // if s2 and e2 have different signs, and its lowest bit is then set.
    final double roundedToOddS2 = Double.longBitsToDouble(
            (s2Bits + (((s2Bits ^ e2Bits) & s2IsInexactInSignBit) >> 63))
            | (s2IsInexactInSignBit >>> 63));
    // Return the correctly rounded value of
    // (s1 + roundedToOddS2) * 2**resultScaleUpExp
    return scalbFiniteF64Sum(s1, roundedToOddS2, resultScaleUpExp);
}
There are a few tricks up the sleeve in the above implementation.
If xBits is the result of Double.doubleToRawLongBits(x):
- xBits ^ (xBits - 1) is equal to -1 if x is +0.0 or -0.0, and is
  non-negative otherwise.
- xBits ^ (xBits + 0x7FFF_FFFF_FFFF_FFFFL) is equal to Long.MAX_VALUE if x is
  +0.0 or -0.0, and is negative otherwise.
The round-to-odd result of a + b is used in a few places in the above
implementation of fma to ensure that the final sum is correctly rounded.
Rounding an intermediate sum to odd forces its least significant bit to 1
whenever the sum is inexact. This keeps the "sticky" information that the
final round-to-nearest step needs, which avoids double-rounding errors.