Hi,

both Math.fma() variants are annotated with @IntrinsicCandidate. This means 
that, in compiled machine code, the JVM may replace the methods with a 
platform-specific optimised instruction sequence, possibly just a single 
hardware instruction.

While your proposal seems to be more efficient than the current Java code, it 
suffers from some drawbacks:
* The methods are far less readable (and maintainable) than the current simple 
Java implementations.
* Being more convoluted, they require extensive review time to make sure they 
are 100% correct.
* They only benefit interpreted code, that is, code that is rarely executed. In 
normal circumstances, hot code does not execute the Java version.
All in all, given the combination of the above, I'm not sure that a review 
effort would benefit the general public.

Also, your claim about improved efficiency should be substantiated by JMH 
results (https://github.com/openjdk/jmh) to convince this audience.

Greetings
Raffaello


________________________________________
From: core-libs-dev <[email protected]> on behalf of John Platts 
<[email protected]>
Sent: Sunday, December 21, 2025 06:12
To: [email protected]
Subject: Improving performance of the default implementation of 
java.lang.Math.fma

There is a much better way to implement java.lang.Math.fma(double, double, 
double) than the current default implementation in 
src/java.base/share/classes/java/lang/Math.java.
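
For context, the following is a minimal sketch of what a fused multiply-add has to deliver and of the BigDecimal-based computation that this message compares against. The class name FmaReferenceSketch and the helper fmaViaBigDecimal are hypothetical names chosen for illustration. The helper is only meaningful for finite inputs and ignores the sign of a zero result, so it is not a drop-in replacement for Math.fma.

```java
import java.math.BigDecimal;

class FmaReferenceSketch {
    // Hypothetical helper: computes a * b + c exactly with BigDecimal and
    // rounds once to double. Assumes a, b, and c are finite; a zero result
    // loses its sign.
    static double fmaViaBigDecimal(double a, double b, double c) {
        return new BigDecimal(a).multiply(new BigDecimal(b))
                .add(new BigDecimal(c)).doubleValue();
    }

    public static void main(String[] args) {
        double a = 0x1.0000000000001p0; // 1 + 2**-52
        double c = -(a * a);            // negated, already rounded square

        // Two roundings: the product is rounded before c is added, so the
        // low-order term 2**-104 of the exact square is lost.
        System.out.println(Double.toHexString(a * a + c));

        // One rounding: both calls recover the lost term 2**-104.
        System.out.println(Double.toHexString(fmaViaBigDecimal(a, a, c)));
        System.out.println(Double.toHexString(Math.fma(a, a, c)));
    }
}
```

With these inputs, a * a + c evaluates to 0.0, while the single-rounding versions return 0x1.0p-104.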

Here is a faster implementation of java.lang.Math.fma. In benchmarking, it 
was much faster than (new BigDecimal(a)).multiply(new 
BigDecimal(b)).add(new BigDecimal(c)).doubleValue() when a, b, and c are all 
finite floating-point values:
    private static double scalbFiniteF64Sum(double hi, double lo, int exp) {
        // Both hi and lo should be finite, and |hi| >= |lo| should be true if
        // hi != 0.0

        final double sum0 = hi + lo;
        final double scaleResult0 = Math.scalb(sum0, exp);
        if (!(Math.abs(scaleResult0) <= Double.MIN_NORMAL) || lo == 0.0) {
            // If |scaleResult0| > Double.MIN_NORMAL or lo == 0.0, then
            // scaleResult0 will be correctly rounded.

            // Return scaleResult0 in this case.

            return scaleResult0;
        }

        // scaleResult1 is equal to scaleResult0 with its ULP bit cleared.
        // This ensures that the final result will be correctly rounded if
        // the ULP bit of scaleResult0 is set.
        final double scaleResult1 = Double.longBitsToDouble(
                Double.doubleToRawLongBits(scaleResult0) & (-2L));

        // scaleErr is equal to the error of scaling sum0 by 2**exp, scaled back
        // up by 2**-exp. scaleResult1 is used to compute scaleErr to ensure
        // that the value of the ULP bit of scaleResult0, scaled up by 2**-exp,
        // is included in scaleErr to ensure that the final result is correctly
        // rounded.
        final double scaleErr = sum0 - Math.scalb(scaleResult1, -exp);

        // err0 is equal to the error of hi + lo
        final double err0 = (hi - sum0) + lo;

        // Compute scaleErr + err0 using the Fast2Sum algorithm as
        // |scaleErr| >= |err0| should be true if |scaleErr| > 0
        final double sum1 = scaleErr + err0;
        final double err1 = (scaleErr - sum1) + err0;

        final long sum1Bits = Double.doubleToRawLongBits(sum1);
        final long err1Bits = Double.doubleToRawLongBits(err1);
        final long sum1IsInexactInSignBit = err1Bits
                ^ (err1Bits + 0x7FFF_FFFF_FFFF_FFFFL);

        // roundedToOddSum1 is equal to the rounded to odd sum of scaleErr and
        // err0
        final double roundedToOddSum1 = Double.longBitsToDouble((sum1Bits
                + (((sum1Bits ^ err1Bits) & sum1IsInexactInSignBit) >> 63))
                | (sum1IsInexactInSignBit >>> 63));

        // scaleResult2 is equal to the correctly rounded value of
        // (hi + lo) * 2**exp
        final double scaleResult2 = scaleResult1
                + Math.scalb(roundedToOddSum1, exp);

        return scaleResult2;
    }

    public static double fma(double a, double b, double c) {
        final long aBits = Double.doubleToRawLongBits(a);
        final long bBits = Double.doubleToRawLongBits(b);
        final long cBits = Double.doubleToRawLongBits(c);

        final int aBiasedExp = (int) (aBits >>> 52) & 0x7FF;
        final int bBiasedExp = (int) (bBits >>> 52) & 0x7FF;
        final int cBiasedExp = (int) (cBits >>> 52) & 0x7FF;

        if (((aBits ^ (aBits - 1)) | (bBits ^ (bBits - 1))
                | (cBits ^ (cBits - 1)) | (0x7FE - aBiasedExp)
                | (0x7FE - bBiasedExp)) < 0) {
            // If at least one of a, b, or c is zero or at least one of a or b
            // is non-finite, the result will be equal to a * b + c.
            return a * b + c;
        } else if (cBiasedExp == 0x7FF) {
            // If a and b are both nonzero finite numbers and c is a non-finite
            // value, simply return c since the exact result of a * b + c will
            // be equal to c in this case.
            return c;
        }

        // a, b, and c are all nonzero finite values at this point

        // Normalize a and b to normal floating-point numbers

        final int aIsDenormalMask = (aBiasedExp - 1) >> 31;
        final int bIsDenormalMask = (bBiasedExp - 1) >> 31;

        final long aNormalizeAdjBits = (aBits & 0x8000_0000_0000_0000L)
                | (aIsDenormalMask & 0x0350_0000_0000_0000L);
        final long bNormalizeAdjBits = (bBits & 0x8000_0000_0000_0000L)
                | (bIsDenormalMask & 0x0350_0000_0000_0000L);

        final double normalizedA =
                Double.longBitsToDouble(aBits | aNormalizeAdjBits)
                - Double.longBitsToDouble(aNormalizeAdjBits);
        final double normalizedB =
                Double.longBitsToDouble(bBits | bNormalizeAdjBits)
                - Double.longBitsToDouble(bNormalizeAdjBits);

        // If a is a denormal number, normalizedA is equal to a * 2**52.
        // Otherwise, if a is already a normal number, normalizedA is equal to
        // a.

        // If b is a denormal number, normalizedB is equal to b * 2**52.
        // Otherwise, if b is already a normal number, normalizedB is equal to
        // b.

        final long normalizedABits = Double.doubleToRawLongBits(normalizedA);
        final long normalizedBBits = Double.doubleToRawLongBits(normalizedB);

        final int normalizedABiasedExp = (int) (normalizedABits >>> 52) & 0x7FF;
        final int normalizedBBiasedExp = (int) (normalizedBBits >>> 52) & 0x7FF;

        // minPBiasedExp is the smallest possible biased exponent (with an
        // exponent bias of 1023) of the exact value of a * b.
        final int minPBiasedExp = normalizedABiasedExp + normalizedBBiasedExp
                + (aIsDenormalMask & -52) + (bIsDenormalMask & -52) + -1023;

        if (minPBiasedExp >= 2048) {
            // If minPBiasedExp >= 2048, then a * b + c is known to overflow to
            // infinity. Return a * b + c in this case.
            return a * b + c;
        }

        // Normalize c to a normal floating-point number

        final int cIsDenormalMask = (cBiasedExp - 1) >> 31;

        final long cNormalizeAdjBits = (cBits & 0x8000_0000_0000_0000L)
                | (cIsDenormalMask & 0x0350_0000_0000_0000L);

        final double normalizedC =
                Double.longBitsToDouble(cBits | cNormalizeAdjBits)
                - Double.longBitsToDouble(cNormalizeAdjBits);

        // If c is a denormal number, normalizedC is equal to c * 2**52.
        // Otherwise, if c is already a normal number, normalizedC is equal to
        // c.

        final long normalizedCBits = Double.doubleToRawLongBits(normalizedC);
        final int normalizedCBiasedExp = (int) (normalizedCBits >>> 52) & 0x7FF;

        final int adjCBiasedExp = normalizedCBiasedExp
                + (cIsDenormalMask & -52);

        // adjCBiasedExp is equal to floor(log2(|c|)) + 1023, even if c is a
        // denormal number

        final int expDiff = adjCBiasedExp - minPBiasedExp;
        if (expDiff >= 55) {
            // If expDiff is greater than or equal to 55, the exact value of
            // |a * b| is less than 0.5 * ulp(c). The correctly rounded result
            // of a * b + c is known to be equal to c in this case. Return c in
            // this case as the exact value of |a * b| is too small to affect
            // the correctly rounded result of a * b + c.

            return c;
        }

        // aMant is equal to the mantissa of a, with 1 <= |a| < 2
        // bMant is equal to the mantissa of b, with 1 <= |b| < 2
        final double aMant =
                Double.longBitsToDouble(((normalizedABits
                        | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL));
        final double bMant =
                Double.longBitsToDouble(((normalizedBBits
                        | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL));

        // Split aMant and bMant using Veltkamp-Dekker splitting
        final double aMantGamma = aMant * 134217729.0;
        final double bMantGamma = bMant * 134217729.0;

        final double aMantHi = aMantGamma + (aMant - aMantGamma);
        final double bMantHi = bMantGamma + (bMant - bMantGamma);

        final double aMantLo = aMant - aMantHi;
        final double bMantLo = bMant - bMantHi;

        final double pHi = aMant * bMant;
        final double pLo = ((aMantHi * bMantHi - pHi)
                + (aMantHi * bMantLo + aMantLo * bMantHi)) + aMantLo * bMantLo;

        if (minPBiasedExp >= 1 && pLo == 0.0) {
            // If minPBiasedExp >= 1 and pLo == 0.0 are both true, then a * b is
            // known to be exact and normal. Return a * b + c in this case.

            return a * b + c;
        }

        final int resultScaleUpExp = minPBiasedExp - 1023;

        if (expDiff <= -106) {
            // If expDiff <= -106 is true, then |a * b| >= 2**-968 must be true
            // since 2**-1074 <= |c| < 2**-105 * |a * b|.

            if (pLo == 0.0) {
                // If expDiff <= -106 and pLo == 0.0 are both true, then |c|
                // is too small to affect the correctly rounded result.

                // Return a * b in this case since a * b is known to be either
                // an exact normal finite number or infinity and as |c| is too
                // small to affect the result of a * b + c.
                return a * b;
            } else {
                // If expDiff <= -106 and pLo != 0.0 are both true, then
                // 0 < |c * 2**-resultScaleUpExp| < 0.5 * ulp(pLo) is known
                // to be true.

                // In this case, the rounded to odd sum of pLo and
                // c * 2**-resultScaleUpExp can be computed by decrementing
                // pLoBits by 1 if pLo and c have different signs, followed
                // by setting the LSB.

                final long pLoBits = Double.doubleToRawLongBits(pLo);

                // roundedToOddLoSum is equal to the rounded to odd sum of
                // pLo and c * 2**-resultScaleUpExp
                final double roundedToOddLoSum = Double.longBitsToDouble(
                        (pLoBits + ((pLoBits ^ cBits) >> 63)) | 1L);

                // Return the correctly rounded value of
                // (pHi + roundedToOddLoSum) * 2**resultScaleUpExp
                return scalbFiniteF64Sum(pHi, roundedToOddLoSum,
                        resultScaleUpExp);
            }
        }

        // -105 <= expDiff <= 54 is now true at this point

        final double scaledC =
            Double.longBitsToDouble(((normalizedCBits | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL) + ((long) expDiff << 52));

        // Compute pLo + scaledC using the 2Sum algorithm

        final double s0 = pLo + scaledC;
        final double v0 = s0 - pLo;
        final double e0 = (pLo - (s0 - v0)) + (scaledC - v0);

        // s0 + e0 == pLo + scaledC

        // Compute pHi + s0 using the 2Sum algorithm

        final double s1 = pHi + s0;
        final double v1 = s1 - pHi;
        final double e1 = (pHi - (s1 - v1)) + (s0 - v1);

        // s1 + e1 + e0 == pHi + pLo + scaledC

        // Compute the rounded to odd sum of e1 and e0 to ensure that the final
        // result is correctly rounded.

        // Fast2Sum is sufficient for computing e1 + e0 as either |e1| == 0.0 or
        // |e1| >= |e0| should be true

        final double s2 = e1 + e0;
        final double e2 = (e1 - s2) + e0;

        // s1 + s2 + e2 == pHi + pLo + scaledC

        final long s2Bits = Double.doubleToRawLongBits(s2);
        final long e2Bits = Double.doubleToRawLongBits(e2);
        final long s2IsInexactInSignBit = e2Bits
                ^ (e2Bits + 0x7FFF_FFFF_FFFF_FFFFL);

        // roundedToOddS2 is equal to the rounded to odd sum of e1 and e0 and
        // is used to ensure that the final result is correctly rounded
        final double roundedToOddS2 = Double.longBitsToDouble(
                (s2Bits + (((s2Bits ^ e2Bits) & s2IsInexactInSignBit) >> 63))
                        | (s2IsInexactInSignBit >>> 63));

        // Return the correctly rounded value of
        // (s1 + roundedToOddS2) * 2**resultScaleUpExp
        return scalbFiniteF64Sum(s1, roundedToOddS2, resultScaleUpExp);
    }
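
The Veltkamp-Dekker splitting step in fma above can be checked in isolation. The sketch below copies the same expressions into a hypothetical helper twoProduct (the class name DekkerProductSketch and the method isExact are likewise made up for illustration) and compares hi + lo against the exact product computed with BigDecimal. It assumes inputs in [1, 2), as is the case for aMant and bMant, so that no intermediate overflow or underflow can occur.

```java
import java.math.BigDecimal;

class DekkerProductSketch {
    // 2**27 + 1, the splitting constant used in the fma code above.
    static final double SPLITTER = 134217729.0;

    // Returns {hi, lo} where hi is the rounded product and lo is intended
    // to be the rounding error, so that hi + lo == x * y exactly.
    // Assumes 1 <= |x| < 2 and 1 <= |y| < 2.
    static double[] twoProduct(double x, double y) {
        double xGamma = x * SPLITTER;
        double yGamma = y * SPLITTER;
        double xHi = xGamma + (x - xGamma); // upper half of the significand
        double yHi = yGamma + (y - yGamma);
        double xLo = x - xHi;               // lower half of the significand
        double yLo = y - yHi;
        double hi = x * y;
        double lo = ((xHi * yHi - hi) + (xHi * yLo + xLo * yHi)) + xLo * yLo;
        return new double[] {hi, lo};
    }

    // Checks hi + lo against the exact product computed with BigDecimal.
    static boolean isExact(double x, double y) {
        double[] p = twoProduct(x, y);
        BigDecimal exact = new BigDecimal(x).multiply(new BigDecimal(y));
        BigDecimal sum = new BigDecimal(p[0]).add(new BigDecimal(p[1]));
        return exact.compareTo(sum) == 0;
    }

    public static void main(String[] args) {
        System.out.println(isExact(4.0 / 3.0, 1.1));
        System.out.println(isExact(Math.PI / 2, Math.E / 2));
    }
}
```

When lo comes out as 0.0, the product is exactly representable, which is the condition the fma code tests with pLo == 0.0.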

The above implementation relies on a few tricks.

If xBits is the result of Double.doubleToRawLongBits(x):
* xBits ^ (xBits - 1) is equal to -1 if and only if x is 0.0 or -0.0, and is 
non-negative otherwise.
* xBits ^ (xBits + 0x7FFF_FFFF_FFFF_FFFFL) is negative if and only if x is 
nonzero, and is equal to Long.MAX_VALUE if x is 0.0 or -0.0.
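
Both identities can be tried out with a small stand-alone sketch. The class name ZeroTestBitTricks and the two method names are hypothetical and only wrap the expressions quoted above.

```java
class ZeroTestBitTricks {
    // True if and only if x is 0.0 or -0.0: only the bit patterns 0 and
    // 0x8000_0000_0000_0000 make bits ^ (bits - 1) equal to -1.
    static boolean isZeroViaXorDec(double x) {
        long bits = Double.doubleToRawLongBits(x);
        return (bits ^ (bits - 1)) == -1L;
    }

    // True if and only if x is not a zero (this includes NaN and the
    // infinities): the sign bit of the XOR is set exactly when any of the
    // low 63 bits of the bit pattern is set.
    static boolean isNonZeroViaXorAdd(double x) {
        long bits = Double.doubleToRawLongBits(x);
        return (bits ^ (bits + 0x7FFF_FFFF_FFFF_FFFFL)) < 0;
    }

    public static void main(String[] args) {
        double[] samples = {0.0, -0.0, 1.0, -1.0, Double.MIN_VALUE,
                -Double.MIN_VALUE, Double.NaN, Double.NEGATIVE_INFINITY};
        for (double x : samples) {
            System.out.println(x + ": zero=" + isZeroViaXorDec(x)
                    + ", nonzero=" + isNonZeroViaXorAdd(x));
        }
    }
}
```

In fma, the first expression is OR-ed with the other conditions and tested with a single "< 0" comparison, which avoids separate branches for each operand.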

The rounded to odd result of a + b is used in a few places in the above 
implementation of fma to ensure that the final sum is correctly rounded.
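
As an illustration of rounding to odd, here is a minimal sketch written with explicit branches instead of the branch-free bit manipulation used in fma. The class RoundToOddSketch and the method addRoundToOdd are hypothetical names. The sketch assumes finite inputs and a finite sum, and uses the 2Sum algorithm so that no ordering of |a| and |b| is required.

```java
class RoundToOddSketch {
    // Returns a + b rounded to odd: the exact sum if it is representable,
    // otherwise the neighbouring double whose significand LSB is 1.
    // Assumes a, b, and a + b are all finite.
    static double addRoundToOdd(double a, double b) {
        double s = a + b;                   // round-to-nearest sum
        double v = s - a;
        double e = (a - (s - v)) + (b - v); // 2Sum: s + e == a + b exactly
        if (e == 0.0) {
            return s;                       // the sum is exact
        }
        long sBits = Double.doubleToRawLongBits(s);
        long eBits = Double.doubleToRawLongBits(e);
        if ((sBits ^ eBits) < 0) {
            // s and e have different signs, so the exact sum is smaller in
            // magnitude than s: step to the next double toward zero.
            sBits--;
        }
        // Setting the LSB selects the odd neighbour of the exact sum.
        return Double.longBitsToDouble(sBits | 1L);
    }

    public static void main(String[] args) {
        System.out.println(addRoundToOdd(1.0, 0x1p-60) == Math.nextUp(1.0));
        System.out.println(addRoundToOdd(1.0, -0x1p-60) == Math.nextDown(1.0));
        System.out.println(addRoundToOdd(1.0, 2.0) == 3.0);
    }
}
```

The odd LSB acts as a sticky bit: it records that the low-order sum was inexact, so that a later round-to-nearest addition of the high part does not hit a false tie.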
