Author: Chaitanya Koparkar
Date: 2025-08-20T14:51:40+01:00
New Revision: f649605bcf5e9455a4a13e51bec8d7fa89bc6b4c

URL: https://github.com/llvm/llvm-project/commit/f649605bcf5e9455a4a13e51bec8d7fa89bc6b4c
DIFF: https://github.com/llvm/llvm-project/commit/f649605bcf5e9455a4a13e51bec8d7fa89bc6b4c.diff

LOG: [clang] Enable constexpr handling for __builtin_elementwise_fma (#152919)

Fixes https://github.com/llvm/llvm-project/issues/152455.

Added: 
    

Modified: 
    clang/docs/LanguageExtensions.rst
    clang/include/clang/Basic/Builtins.td
    clang/lib/AST/ByteCode/InterpBuiltin.cpp
    clang/lib/AST/ExprConstant.cpp
    clang/test/CodeGen/rounding-math.cpp
    clang/test/Sema/constant-builtins-vector.cpp

Removed: 
    


################################################################################
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 12ca4cf42f7cc..6a83d12ce3840 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -767,12 +767,12 @@ elementwise to the input.
 
 Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
 
-The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
+The elementwise intrinsics ``__builtin_elementwise_popcount``,
 ``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
 ``__builtin_elementwise_sub_sat``, ``__builtin_elementwise_max``,
 ``__builtin_elementwise_min``, ``__builtin_elementwise_abs``,
-``__builtin_elementwise_ctlz``, and ``__builtin_elementwise_cttz`` can be
-called in a ``constexpr`` context.
+``__builtin_elementwise_ctlz``, ``__builtin_elementwise_cttz``, and
+``__builtin_elementwise_fma`` can be called in a ``constexpr`` context.
 
 No implicit promotion of integer types takes place. The mixing of integer types
 of different sizes and signs is forbidden in binary and ternary builtins.
@@ -4389,7 +4389,7 @@ fall into one of the specified floating-point classes.
 
   if (__builtin_isfpclass(x, 448)) {
      // `x` is positive finite value
-        ...
+         ...
   }
 
 **Description**:

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index ad340e2ed0eec..332f369a9032f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
 
 def ElementwiseFma : Builtin {
   let Spellings = ["__builtin_elementwise_fma"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
   let Prototype = "void(...)";
 }
 

diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index fd8c70c392dcb..5de5091178b8f 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2714,6 +2714,62 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+                                            const CallExpr *Call) {
+  assert(Call->getNumArgs() == 3);
+
+  FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts());
+  llvm::RoundingMode RM = getRoundingMode(FPO);
+  const QualType Arg1Type = Call->getArg(0)->getType();
+  const QualType Arg2Type = Call->getArg(1)->getType();
+  const QualType Arg3Type = Call->getArg(2)->getType();
+
+  // Non-vector floating point types.
+  if (!Arg1Type->isVectorType()) {
+    assert(!Arg2Type->isVectorType());
+    assert(!Arg3Type->isVectorType());
+
+    const Floating &Z = S.Stk.pop<Floating>();
+    const Floating &Y = S.Stk.pop<Floating>();
+    const Floating &X = S.Stk.pop<Floating>();
+    APFloat F = X.getAPFloat();
+    F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+    Floating Result = S.allocFloat(X.getSemantics());
+    Result.copy(F);
+    S.Stk.push<Floating>(Result);
+    return true;
+  }
+
+  // Vector type.
+  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
+         Arg3Type->isVectorType());
+
+  const VectorType *VecT = Arg1Type->castAs<VectorType>();
+  const QualType ElemT = VecT->getElementType();
+  unsigned NumElems = VecT->getNumElements();
+
+  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
+         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
+         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+  assert(ElemT->isRealFloatingType());
+
+  const Pointer &VZ = S.Stk.pop<Pointer>();
+  const Pointer &VY = S.Stk.pop<Pointer>();
+  const Pointer &VX = S.Stk.pop<Pointer>();
+  const Pointer &Dst = S.Stk.peek<Pointer>();
+  for (unsigned I = 0; I != NumElems; ++I) {
+    using T = PrimConv<PT_Float>::T;
+    APFloat X = VX.elem<T>(I).getAPFloat();
+    APFloat Y = VY.elem<T>(I).getAPFloat();
+    APFloat Z = VZ.elem<T>(I).getAPFloat();
+    (void)X.fusedMultiplyAdd(Y, Z, RM);
+    Dst.elem<Floating>(I) = Floating(X);
+  }
+  Dst.initializeAllElements();
+  return true;
+}
+
 bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
                       uint32_t BuiltinID) {
   if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3145,6 +3201,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
   case clang::X86::BI__builtin_ia32_pmuludq128:
   case clang::X86::BI__builtin_ia32_pmuludq256:
     return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID);
+  case Builtin::BI__builtin_elementwise_fma:
+    return interp__builtin_elementwise_fma(S, OpPC, Call);
 
   default:
     S.FFDiag(S.Current->getLocation(OpPC),

diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 9c87a88899647..a03e64fcffde2 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11874,6 +11874,28 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
   }
+
+  case Builtin::BI__builtin_elementwise_fma: {
+    APValue SourceX, SourceY, SourceZ;
+    if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
+        !EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
+        !EvaluateAsRValue(Info, E->getArg(2), SourceZ))
+      return false;
+
+    unsigned SourceLen = SourceX.getVectorLength();
+    SmallVector<APValue> ResultElements;
+    ResultElements.reserve(SourceLen);
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+      const APFloat &X = SourceX.getVectorElt(EltNum).getFloat();
+      const APFloat &Y = SourceY.getVectorElt(EltNum).getFloat();
+      const APFloat &Z = SourceZ.getVectorElt(EltNum).getFloat();
+      APFloat Result(X);
+      (void)Result.fusedMultiplyAdd(Y, Z, RM);
+      ResultElements.push_back(APValue(Result));
+    }
+    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+  }
   }
 }
 
@@ -16139,6 +16161,21 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
     Result = minimumnum(Result, RHS);
     return true;
   }
+
+  case Builtin::BI__builtin_elementwise_fma: {
+    if (!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() ||
+        !E->getArg(2)->isPRValue()) {
+      return false;
+    }
+    APFloat SourceY(0.), SourceZ(0.);
+    if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+        !EvaluateFloat(E->getArg(1), SourceY, Info) ||
+        !EvaluateFloat(E->getArg(2), SourceZ, Info))
+      return false;
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+    (void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
+    return true;
+  }
   }
 }
 

diff --git a/clang/test/CodeGen/rounding-math.cpp b/clang/test/CodeGen/rounding-math.cpp
index 264031dc9daa9..5c44fd31242c6 100644
--- a/clang/test/CodeGen/rounding-math.cpp
+++ b/clang/test/CodeGen/rounding-math.cpp
@@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
 // CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
 // CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
 // CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4
+
+void test_builtin_elementwise_fma_round_upward() {
+  #pragma STDC FENV_ACCESS ON
+  #pragma STDC FENV_ROUND FE_UPWARD
+
+  // CHECK: store float 0x4018000100000000, ptr %f1
+  // CHECK: store float 0x4018000100000000, ptr %f2
+  constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+  constexpr float f2 = 2.0F * 3.000001F + 0.000001F;
+  static_assert(f1 == f2);
+  static_assert(f1 == 6.00000381F);
+  // CHECK: store double 0x40180000C9539B89, ptr %d1
+  // CHECK: store double 0x40180000C9539B89, ptr %d2
+  constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+  constexpr double d2 = 2.0 * 3.000001 + 0.000001;
+  static_assert(d1 == d2);
+  static_assert(d1 == 6.0000030000000004);
+}
+
+void test_builtin_elementwise_fma_round_downward() {
+  #pragma STDC FENV_ACCESS ON
+  #pragma STDC FENV_ROUND FE_DOWNWARD
+
+  // CHECK: store float 0x40180000C0000000, ptr %f3
+  // CHECK: store float 0x40180000C0000000, ptr %f4
+  constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+  constexpr float f4 = 2.0F * 3.000001F + 0.000001F;
+  static_assert(f3 == f4);
+  // CHECK: store double 0x40180000C9539B87, ptr %d3
+  // CHECK: store double 0x40180000C9539B87, ptr %d4
+  constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+  constexpr double d4 = 2.0 * 3.000001 + 0.000001;
+  static_assert(d3 == d4);
+}
+
+void test_builtin_elementwise_fma_round_nearest() {
+  #pragma STDC FENV_ACCESS ON
+  #pragma STDC FENV_ROUND FE_TONEAREST
+
+  // CHECK: store float 0x40180000C0000000, ptr %f5
+  // CHECK: store float 0x40180000C0000000, ptr %f6
+  constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
+  constexpr float f6 = 2.0F * 3.000001F + 0.000001F;
+  static_assert(f5 == f6);
+  static_assert(f5 == 6.00000286F);
+  // CHECK: store double 0x40180000C9539B89, ptr %d5
+  // CHECK: store double 0x40180000C9539B89, ptr %d6
+  constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
+  constexpr double d6 = 2.0 * 3.000001 + 0.000001;
+  static_assert(d5 == d6);
+  static_assert(d5 == 6.0000030000000004);
+}

diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp
index 7f882f9ee76eb..9c52a2ab20c7e 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -936,3 +936,24 @@ constexpr vector4char ctz1 = __builtin_elementwise_cttz((vector4char){1, 0, 3, 4
 // expected-note@-1 {{evaluation of __builtin_elementwise_cttz with a zero value is undefined}}
 static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){8, 0, 127, 0}, (vector4char){9, -1, 9, -2})) == (LITTLE_END ? 0xFE00FF03 : 0x03FF00FE));
 static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){0, 0, 0, 0}, (vector4char){0, 0, 0, 0})) == 0);
+
+// Non-vector floating point types.
+static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
+static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
+// Vector type.
+constexpr vector4float fmaFloat1 =
+  __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
+                            (vector4float){2.0, 3.0, 4.0, 5.0},
+                            (vector4float){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaFloat1[0] == 5.0);
+static_assert(fmaFloat1[1] == 10.0);
+static_assert(fmaFloat1[2] == 17.0);
+static_assert(fmaFloat1[3] == 26.0);
+constexpr vector4double fmaDouble1 =
+  __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
+                            (vector4double){2.0, 3.0, 4.0, 5.0},
+                            (vector4double){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaDouble1[0] == 5.0);
+static_assert(fmaDouble1[1] == 10.0);
+static_assert(fmaDouble1[2] == 17.0);
+static_assert(fmaDouble1[3] == 26.0);


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to