llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Iris Shi (el-ev)

<details>
<summary>Changes</summary>

- Follow-up to #<!-- -->152294

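Adds floating-point support (scalar and vector operands) to the constant evaluation of `__builtin_elementwise_max` and `__builtin_elementwise_min`, in both the bytecode interpreter and the `ExprConstant` evaluator.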

---
Full diff: https://github.com/llvm/llvm-project/pull/153563.diff


4 Files Affected:

- (modified) clang/docs/ReleaseNotes.rst (+2-2) 
- (modified) clang/lib/AST/ByteCode/InterpBuiltin.cpp (+61-41) 
- (modified) clang/lib/AST/ExprConstant.cpp (+32-18) 
- (modified) clang/test/Sema/constant-builtins-vector.cpp (+23) 


``````````diff
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index d109518bca3f3..5d623bf73fea0 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -126,8 +126,8 @@ Non-comprehensive list of changes in this release
   This feature is enabled by default but can be disabled by compiling with
   ``-fno-sanitize-annotate-debug-info-traps``.
 
-- ``__builtin_elementwise_max`` and ``__builtin_elementwise_min`` functions 
for integer types can
-  now be used in constant expressions.
+- ``__builtin_elementwise_max`` and ``__builtin_elementwise_min`` functions can now be used in
+  constant expressions with both integer and floating-point operands.
 
 New Compiler Flags
 ------------------
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp 
b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index ee2d532551583..5375b184ba378 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2339,20 +2339,13 @@ static bool 
interp__builtin_elementwise_maxmin(InterpState &S, CodePtr OpPC,
   assert(Call->getNumArgs() == 2);
 
   QualType Arg0Type = Call->getArg(0)->getType();
+  QualType Arg1Type = Call->getArg(1)->getType();
 
-  // TODO: Support floating-point types.
-  if (!(Arg0Type->isIntegerType() ||
-        (Arg0Type->isVectorType() &&
-         Arg0Type->castAs<VectorType>()->getElementType()->isIntegerType())))
-    return false;
-
-  if (!Arg0Type->isVectorType()) {
-    assert(!Call->getArg(1)->getType()->isVectorType());
-    APSInt RHS = popToAPSInt(
-        S.Stk, *S.getContext().classify(Call->getArg(1)->getType()));
-    APSInt LHS = popToAPSInt(
-        S.Stk, *S.getContext().classify(Call->getArg(0)->getType()));
-    APInt Result;
+  if (Arg0Type->isIntegerType()) {
+    assert(Arg1Type->isIntegerType());
+    APSInt RHS = popToAPSInt(S.Stk, *S.getContext().classify(Arg1Type));
+    APSInt LHS = popToAPSInt(S.Stk, *S.getContext().classify(Arg0Type));
+    APSInt Result;
     if (BuiltinID == Builtin::BI__builtin_elementwise_max) {
       Result = std::max(LHS, RHS);
     } else if (BuiltinID == Builtin::BI__builtin_elementwise_min) {
@@ -2360,47 +2353,74 @@ static bool 
interp__builtin_elementwise_maxmin(InterpState &S, CodePtr OpPC,
     } else {
       llvm_unreachable("Wrong builtin ID");
     }
+    pushInteger(S, Result, Call->getType());
+    return true;
+  }
 
-    pushInteger(S, APSInt(Result, !LHS.isSigned()), Call->getType());
+  if (Arg0Type->isRealFloatingType()) {
+    assert(Arg1Type->isRealFloatingType());
+    APFloat RHS = S.Stk.pop<Floating>().getAPFloat();
+    APFloat LHS = S.Stk.pop<Floating>().getAPFloat();
+    Floating Result = S.allocFloat(RHS.getSemantics());
+    if (BuiltinID == Builtin::BI__builtin_elementwise_max) {
+      Result.copy(maxnum(LHS, RHS));
+    } else if (BuiltinID == Builtin::BI__builtin_elementwise_min) {
+      Result.copy(minnum(LHS, RHS));
+    } else {
+      llvm_unreachable("Wrong builtin ID");
+    }
+    S.Stk.push<Floating>(Result);
     return true;
   }
 
   // Vector case.
-  assert(Call->getArg(0)->getType()->isVectorType() &&
-         Call->getArg(1)->getType()->isVectorType());
-  const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>();
-  assert(VT->getElementType() ==
-         Call->getArg(1)->getType()->castAs<VectorType>()->getElementType());
-  assert(VT->getNumElements() ==
-         Call->getArg(1)->getType()->castAs<VectorType>()->getNumElements());
-  assert(VT->getElementType()->isIntegralOrEnumerationType());
+  assert(Arg0Type->isVectorType() && Arg1Type->isVectorType());
+
+  const auto *VT = Arg0Type->castAs<VectorType>();
+  QualType ElemType = VT->getElementType();
+  unsigned NumElems = VT->getNumElements();
+
+  assert(ElemType == Arg1Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg1Type->castAs<VectorType>()->getNumElements());
+  assert(ElemType->isIntegerType() || ElemType->isRealFloatingType());
 
   const Pointer &RHS = S.Stk.pop<Pointer>();
   const Pointer &LHS = S.Stk.pop<Pointer>();
   const Pointer &Dst = S.Stk.peek<Pointer>();
-  PrimType ElemT = *S.getContext().classify(VT->getElementType());
-  unsigned NumElems = VT->getNumElements();
+  PrimType ElemT = *S.getContext().classify(ElemType);
   for (unsigned I = 0; I != NumElems; ++I) {
-    APSInt Elem1;
-    APSInt Elem2;
-    INT_TYPE_SWITCH_NO_BOOL(ElemT, {
-      Elem1 = LHS.elem<T>(I).toAPSInt();
-      Elem2 = RHS.elem<T>(I).toAPSInt();
-    });
+    if (ElemType->isIntegerType()) {
+      APSInt LHSInt;
+      APSInt RHSInt;
+      INT_TYPE_SWITCH_NO_BOOL(ElemT, {
+        LHSInt = LHS.elem<T>(I).toAPSInt();
+        RHSInt = RHS.elem<T>(I).toAPSInt();
+      });
+      
+      APSInt Result;
+      if (BuiltinID == Builtin::BI__builtin_elementwise_max) {
+        Result = std::max(LHSInt, RHSInt);
+      } else if (BuiltinID == Builtin::BI__builtin_elementwise_min) {
+        Result = std::min(LHSInt, RHSInt);
+      } else {
+        llvm_unreachable("Wrong builtin ID");
+      }
 
-    APSInt Result;
-    if (BuiltinID == Builtin::BI__builtin_elementwise_max) {
-      Result = APSInt(std::max(Elem1, Elem2),
-                      Call->getType()->isUnsignedIntegerOrEnumerationType());
-    } else if (BuiltinID == Builtin::BI__builtin_elementwise_min) {
-      Result = APSInt(std::min(Elem1, Elem2),
-                      Call->getType()->isUnsignedIntegerOrEnumerationType());
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(I) = static_cast<T>(Result); });
     } else {
-      llvm_unreachable("Wrong builtin ID");
+      APFloat RHSFloat = RHS.elem<Floating>(I).getAPFloat();
+      APFloat LHSFloat = LHS.elem<Floating>(I).getAPFloat();
+      Floating Result = S.allocFloat(RHSFloat.getSemantics());
+      if (BuiltinID == Builtin::BI__builtin_elementwise_max) {
+        Result.copy(maxnum(LHSFloat, RHSFloat));
+      } else if (BuiltinID == Builtin::BI__builtin_elementwise_min) {
+        Result.copy(minnum(LHSFloat, RHSFloat));
+      } else {
+        llvm_unreachable("Wrong builtin ID");
+      }
+      Dst.elem<Floating>(I) = Result;
     }
-
-    INT_TYPE_SWITCH_NO_BOOL(ElemT,
-                            { Dst.elem<T>(I) = static_cast<T>(Result); });
   }
   Dst.initializeAllElements();
 
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 36dd0f5d7a065..b232cd4a74649 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11697,28 +11697,40 @@ bool VectorExprEvaluator::VisitCallExpr(const 
CallExpr *E) {
 
     QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
 
-    if (!DestEltTy->isIntegerType())
-      return false;
-
     unsigned SourceLen = SourceLHS.getVectorLength();
     SmallVector<APValue, 4> ResultElements;
     ResultElements.reserve(SourceLen);
 
     for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
-      APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
-      APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
-      switch (E->getBuiltinCallee()) {
-      case Builtin::BI__builtin_elementwise_max:
-        ResultElements.push_back(
-            APValue(APSInt(std::max(LHS, RHS),
-                           DestEltTy->isUnsignedIntegerOrEnumerationType())));
-        break;
-      case Builtin::BI__builtin_elementwise_min:
-        ResultElements.push_back(
-            APValue(APSInt(std::min(LHS, RHS),
-                           DestEltTy->isUnsignedIntegerOrEnumerationType())));
-        break;
+      APValue LHS = SourceLHS.getVectorElt(EltNum);
+      APValue RHS = SourceRHS.getVectorElt(EltNum);
+      APValue ResultElt;
+      if (DestEltTy->isIntegerType()) {
+        APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
+        APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
+        switch (E->getBuiltinCallee()) {
+        case Builtin::BI__builtin_elementwise_max:
+          ResultElt = APValue(APSInt(std::max(LHS, RHS),
+                                      
DestEltTy->isUnsignedIntegerOrEnumerationType()));
+          break;
+        case Builtin::BI__builtin_elementwise_min:
+          ResultElt = APValue(APSInt(std::min(LHS, RHS),
+                                      
DestEltTy->isUnsignedIntegerOrEnumerationType()));
+          break;
+        }
+      } else if (DestEltTy->isRealFloatingType()) {
+        APFloat LHS = SourceLHS.getVectorElt(EltNum).getFloat();
+        APFloat RHS = SourceRHS.getVectorElt(EltNum).getFloat();
+        switch (E->getBuiltinCallee()) {
+        case Builtin::BI__builtin_elementwise_max:
+          ResultElt = APValue(maxnum(LHS, RHS));
+          break;
+        case Builtin::BI__builtin_elementwise_min:
+          ResultElt = APValue(minnum(LHS, RHS));
+          break;
+        }
       }
+      ResultElements.push_back(std::move(ResultElt));
     }
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
@@ -15917,7 +15929,8 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr 
*E) {
   case Builtin::BI__builtin_fmaxf:
   case Builtin::BI__builtin_fmaxl:
   case Builtin::BI__builtin_fmaxf16:
-  case Builtin::BI__builtin_fmaxf128: {
+  case Builtin::BI__builtin_fmaxf128: 
+  case Builtin::BI__builtin_elementwise_max: {
     APFloat RHS(0.);
     if (!EvaluateFloat(E->getArg(0), Result, Info) ||
         !EvaluateFloat(E->getArg(1), RHS, Info))
@@ -15930,7 +15943,8 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr 
*E) {
   case Builtin::BI__builtin_fminf:
   case Builtin::BI__builtin_fminl:
   case Builtin::BI__builtin_fminf16:
-  case Builtin::BI__builtin_fminf128: {
+  case Builtin::BI__builtin_fminf128:
+  case Builtin::BI__builtin_elementwise_min: {
     APFloat RHS(0.);
     if (!EvaluateFloat(E->getArg(0), Result, Info) ||
         !EvaluateFloat(E->getArg(1), RHS, Info))
diff --git a/clang/test/Sema/constant-builtins-vector.cpp 
b/clang/test/Sema/constant-builtins-vector.cpp
index bc575dca98d77..06f1067409dca 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -865,6 +865,8 @@ static_assert(__builtin_elementwise_max(1, 2) == 2);
 static_assert(__builtin_elementwise_max(-1, 1) == 1);
 static_assert(__builtin_elementwise_max(1U, 2U) == 2U);
 static_assert(__builtin_elementwise_max(~0U, 0U) == ~0U);
+static_assert(__builtin_fabs(__builtin_elementwise_max(1.0f, 2.0f) - 2.0f) < 
1e-6);
+static_assert(__builtin_fabs(__builtin_elementwise_max(-1.0f, 1.0f) - 1.0f) < 
1e-6);
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_max((vector4char){1, -2, 3, -4}, (vector4char){4, -3, 2, 
-1})) == (LITTLE_END ? 0xFF03FE04 : 0x04FE03FF ));
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_max((vector4uchar){1, 2, 3, 4}, (vector4uchar){4, 3, 2, 
1})) == 0x04030304U);
 static_assert(__builtin_bit_cast(unsigned long long, 
__builtin_elementwise_max((vector4short){1, -2, 3, -4}, (vector4short){4, -3, 
2, -1})) == (LITTLE_END ? 0xFFFF0003FFFE0004 : 0x0004FFFE0003FFFF));
@@ -873,6 +875,27 @@ static_assert(__builtin_elementwise_min(1, 2) == 1);
 static_assert(__builtin_elementwise_min(-1, 1) == -1);
 static_assert(__builtin_elementwise_min(1U, 2U) == 1U);
 static_assert(__builtin_elementwise_min(~0U, 0U) == 0U);
+static_assert(__builtin_fabs(__builtin_elementwise_min(1.0f, 2.0f) - 1.0f) < 
1e-6);
+static_assert(__builtin_fabs(__builtin_elementwise_min(-1.0f, 1.0f) - (-1.0f)) 
< 1e-6);
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_min((vector4char){1, -2, 3, -4}, (vector4char){4, -3, 2, 
-1})) == (LITTLE_END ? 0xFC02FD01 : 0x01FD02FC));
 static_assert(__builtin_bit_cast(unsigned, 
__builtin_elementwise_min((vector4uchar){1, 2, 3, 4}, (vector4uchar){4, 3, 2, 
1})) == 0x01020201U);
 static_assert(__builtin_bit_cast(unsigned long long, 
__builtin_elementwise_min((vector4short){1, -2, 3, -4}, (vector4short){4, -3, 
2, -1})) == (LITTLE_END ? 0xFFFC0002FFFD0001 : 0x0001FFFD0002FFFC));
+
+#define CHECK_VECTOR4_FLOAT_EQ(v1, v2) /* |v1[i] - v2[i]| < 1e-6, i = 0..3 */ \
+    static_assert(__builtin_fabs((v1)[0] - (v2)[0]) < 1e-6 &&                  \
+                  __builtin_fabs((v1)[1] - (v2)[1]) < 1e-6 &&                  \
+                  __builtin_fabs((v1)[2] - (v2)[2]) < 1e-6 &&                  \
+                  __builtin_fabs((v1)[3] - (v2)[3]) < 1e-6);
+CHECK_VECTOR4_FLOAT_EQ(
+    (__builtin_elementwise_max((vector4float){1.0f, -2.0f, 3.0f, -4.0f}, (vector4float){4.0f, -3.0f, 2.0f, -1.0f})),
+    ((vector4float){4.0f, -2.0f, 3.0f, -1.0f}))
+CHECK_VECTOR4_FLOAT_EQ(
+    (__builtin_elementwise_max((vector4double){1.0, -2.0, 3.0, -4.0}, (vector4double){4.0, -3.0, 2.0, -1.0})),
+    ((vector4double){4.0, -2.0, 3.0, -1.0}))
+CHECK_VECTOR4_FLOAT_EQ(
+    (__builtin_elementwise_min((vector4float){1.0f, -2.0f, 3.0f, -4.0f}, (vector4float){4.0f, -3.0f, 2.0f, -1.0f})),
+    ((vector4float){1.0f, -3.0f, 2.0f, -4.0f}))
+CHECK_VECTOR4_FLOAT_EQ(
+    (__builtin_elementwise_min((vector4double){1.0, -2.0, 3.0, -4.0}, (vector4double){4.0, -3.0, 2.0, -1.0})),
+    ((vector4double){1.0, -3.0, 2.0, -4.0}))
+#undef CHECK_VECTOR4_FLOAT_EQ

``````````

</details>


https://github.com/llvm/llvm-project/pull/153563
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to