https://github.com/RKSimon created 
https://github.com/llvm/llvm-project/pull/156017

Allows us to handle OP(vector, scalar) cases, such as x86 shift/rotate by
immediate.

I've included some of the x86 vector rotate-by-immediate constexpr handling to
show how I expect this to work.

CC @Arghnews - if I've got this right, you should be able to add the
shift-by-immediate builtin IDs so that they share the code of the corresponding
shl/lshr/ashr shift-by-variable cases.

>From 6c66244d28c53c9f74c3f159f985595cac7252fd Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <[email protected]>
Date: Fri, 29 Aug 2025 13:29:00 +0100
Subject: [PATCH] [Clang][bytecode] interp__builtin_elementwise_int_binop -
 allow RHS operand to be a scalar

Allows us to handle OP(vector, scalar) cases such as x86 shift/rotate by 
immediate
---
 clang/include/clang/Basic/BuiltinsX86.td   | 17 +++--
 clang/lib/AST/ByteCode/InterpBuiltin.cpp   | 74 ++++++++++++++++++----
 clang/lib/AST/ExprConstant.cpp             | 58 +++++++++++++++++
 clang/test/CodeGen/X86/avx512vl-builtins.c |  1 +
 clang/test/CodeGen/X86/xop-builtins.c      |  1 +
 5 files changed, 133 insertions(+), 18 deletions(-)

diff --git a/clang/include/clang/Basic/BuiltinsX86.td 
b/clang/include/clang/Basic/BuiltinsX86.td
index 5874aee6c83fc..11c72749ca511 100644
--- a/clang/include/clang/Basic/BuiltinsX86.td
+++ b/clang/include/clang/Basic/BuiltinsX86.td
@@ -925,10 +925,6 @@ let Features = "xop", Attributes = [NoThrow, Const, 
RequiredVectorWidth<128>] in
   def vphsubwd : X86Builtin<"_Vector<4, int>(_Vector<8, short>)">;
   def vphsubdq : X86Builtin<"_Vector<2, long long int>(_Vector<4, int>)">;
   def vpperm : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, 
char>, _Vector<16, char>)">;
-  def vprotbi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Constant 
char)">;
-  def vprotwi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Constant 
char)">;
-  def vprotdi : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant char)">;
-  def vprotqi : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long 
int>, _Constant char)">;
   def vpshlb : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, 
char>)">;
   def vpshlw : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, 
short>)">;
   def vpshld : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
@@ -953,6 +949,13 @@ let Features = "xop", Attributes = [NoThrow, Const, 
RequiredVectorWidth<128>] in
   def vfrczpd : X86Builtin<"_Vector<2, double>(_Vector<2, double>)">;
 }
 
+let Features = "xop", Attributes = [NoThrow, Const, Constexpr, 
RequiredVectorWidth<128>] in {
+  def vprotbi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Constant 
char)">;
+  def vprotwi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Constant 
char)">;
+  def vprotdi : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant char)">;
+  def vprotqi : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long 
int>, _Constant char)">;
+}
+
 let Features = "xop", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] 
in {
   def vpermil2pd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, 
_Vector<4, double>, _Vector<4, long long int>, _Constant char)">;
   def vpermil2ps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, 
_Vector<8, float>, _Vector<8, int>, _Constant char)">;
@@ -2036,21 +2039,21 @@ let Features = "avx512dq,evex512", Attributes = 
[NoThrow, Const, RequiredVectorW
   def reduceps512_mask : X86Builtin<"_Vector<16, float>(_Vector<16, float>, 
_Constant int, _Vector<16, float>, unsigned short, _Constant int)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, 
RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512", Attributes = [NoThrow, Const, Constexpr, 
RequiredVectorWidth<512>] in {
   def prold512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Constant 
int)">;
   def prord512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Constant 
int)">;
   def prolq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long 
int>, _Constant int)">;
   def prorq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long 
int>, _Constant int)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, 
RequiredVectorWidth<128>] in {
+let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, 
RequiredVectorWidth<128>] in {
   def prold128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant int)">;
   def prord128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant int)">;
   def prolq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long 
int>, _Constant int)">;
   def prorq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long 
int>, _Constant int)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, 
RequiredVectorWidth<256>] in {
+let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, 
RequiredVectorWidth<256>] in {
   def prold256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Constant int)">;
   def prord256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Constant int)">;
   def prolq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long 
int>, _Constant int)">;
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp 
b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index eba0b25997699..c4f4e1c2daa03 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2565,21 +2565,48 @@ static bool interp__builtin_elementwise_int_binop(
     return true;
   }
 
+  const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>();
+  assert(VT->getElementType()->isIntegralOrEnumerationType());
+  PrimType ElemT = *S.getContext().classify(VT->getElementType());
+  unsigned NumElems = VT->getNumElements();
+
+  // Vector + Scalar case.
+  if (!Call->getArg(1)->getType()->isVectorType()) {
+    assert(Call->getArg(1)->getType()->isIntegralOrEnumerationType());
+
+    APSInt RHS = popToAPSInt(
+        S.Stk, *S.getContext().classify(Call->getArg(1)->getType()));
+    const Pointer &LHS = S.Stk.pop<Pointer>();
+    const Pointer &Dst = S.Stk.peek<Pointer>();
+
+  for (unsigned I = 0; I != NumElems; ++I) {
+      APSInt Elem1;
+      INT_TYPE_SWITCH_NO_BOOL(ElemT, {
+        Elem1 = LHS.elem<T>(I).toAPSInt();
+      });
+
+      APSInt Result =
+          APSInt(Fn(Elem1, RHS),
+                 Call->getType()->isUnsignedIntegerOrEnumerationType());
+
+      INT_TYPE_SWITCH_NO_BOOL(ElemT,
+                              { Dst.elem<T>(I) = static_cast<T>(Result); });
+    }
+    Dst.initializeAllElements();
+    return true;
+  }
+
   // Vector case.
   assert(Call->getArg(0)->getType()->isVectorType() &&
          Call->getArg(1)->getType()->isVectorType());
-  const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>();
   assert(VT->getElementType() ==
          Call->getArg(1)->getType()->castAs<VectorType>()->getElementType());
   assert(VT->getNumElements() ==
          Call->getArg(1)->getType()->castAs<VectorType>()->getNumElements());
-  assert(VT->getElementType()->isIntegralOrEnumerationType());
 
   const Pointer &RHS = S.Stk.pop<Pointer>();
   const Pointer &LHS = S.Stk.pop<Pointer>();
   const Pointer &Dst = S.Stk.peek<Pointer>();
-  PrimType ElemT = *S.getContext().classify(VT->getElementType());
-  unsigned NumElems = VT->getNumElements();
   for (unsigned I = 0; I != NumElems; ++I) {
     APSInt Elem1;
     APSInt Elem2;
@@ -2596,7 +2623,6 @@ static bool interp__builtin_elementwise_int_binop(
                             { Dst.elem<T>(I) = static_cast<T>(Result); });
   }
   Dst.initializeAllElements();
-
   return true;
 }
 
@@ -3256,8 +3282,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const 
CallExpr *Call,
   case clang::X86::BI__builtin_ia32_psllv8si:
     return interp__builtin_elementwise_int_binop(
         S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) {
-          if (RHS.uge(RHS.getBitWidth())) {
-            return APInt::getZero(RHS.getBitWidth());
+          if (RHS.uge(LHS.getBitWidth())) {
+            return APInt::getZero(LHS.getBitWidth());
           }
           return LHS.shl(RHS.getZExtValue());
         });
@@ -3266,8 +3292,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const 
CallExpr *Call,
   case clang::X86::BI__builtin_ia32_psrav8si:
     return interp__builtin_elementwise_int_binop(
         S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) {
-          if (RHS.uge(RHS.getBitWidth())) {
-            return LHS.ashr(RHS.getBitWidth() - 1);
+          if (RHS.uge(LHS.getBitWidth())) {
+            return LHS.ashr(LHS.getBitWidth() - 1);
           }
           return LHS.ashr(RHS.getZExtValue());
         });
@@ -3278,12 +3304,38 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, 
const CallExpr *Call,
   case clang::X86::BI__builtin_ia32_psrlv8si:
     return interp__builtin_elementwise_int_binop(
         S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) {
-          if (RHS.uge(RHS.getBitWidth())) {
-            return APInt::getZero(RHS.getBitWidth());
+          if (RHS.uge(LHS.getBitWidth())) {
+            return APInt::getZero(LHS.getBitWidth());
           }
           return LHS.lshr(RHS.getZExtValue());
         });
 
+  case clang::X86::BI__builtin_ia32_vprotbi:
+  case clang::X86::BI__builtin_ia32_vprotdi:
+  case clang::X86::BI__builtin_ia32_vprotqi:
+  case clang::X86::BI__builtin_ia32_vprotwi:
+  case clang::X86::BI__builtin_ia32_prold128:
+  case clang::X86::BI__builtin_ia32_prold256:
+  case clang::X86::BI__builtin_ia32_prold512:
+  case clang::X86::BI__builtin_ia32_prolq128:
+  case clang::X86::BI__builtin_ia32_prolq256:
+  case clang::X86::BI__builtin_ia32_prolq512:
+    return interp__builtin_elementwise_int_binop(
+        S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) {
+          return LHS.rotl(RHS.urem(LHS.getBitWidth()));
+        });
+
+  case clang::X86::BI__builtin_ia32_prord128:
+  case clang::X86::BI__builtin_ia32_prord256:
+  case clang::X86::BI__builtin_ia32_prord512:
+  case clang::X86::BI__builtin_ia32_prorq128:
+  case clang::X86::BI__builtin_ia32_prorq256:
+  case clang::X86::BI__builtin_ia32_prorq512:
+    return interp__builtin_elementwise_int_binop(
+        S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) {
+          return LHS.rotr(RHS.urem(LHS.getBitWidth()));
+        });
+
   case Builtin::BI__builtin_elementwise_max:
   case Builtin::BI__builtin_elementwise_min:
     return interp__builtin_elementwise_maxmin(S, OpPC, Call, BuiltinID);
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index e7dc1d1ca6c27..b318540bba1e3 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11746,6 +11746,64 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr 
*E) {
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
   }
+  case clang::X86::BI__builtin_ia32_vprotbi:
+  case clang::X86::BI__builtin_ia32_vprotdi:
+  case clang::X86::BI__builtin_ia32_vprotqi:
+  case clang::X86::BI__builtin_ia32_vprotwi:
+  case clang::X86::BI__builtin_ia32_prold128:
+  case clang::X86::BI__builtin_ia32_prold256:
+  case clang::X86::BI__builtin_ia32_prold512:
+  case clang::X86::BI__builtin_ia32_prolq128:
+  case clang::X86::BI__builtin_ia32_prolq256:
+  case clang::X86::BI__builtin_ia32_prolq512: {
+    APValue SourceLHS, SourceRHS;
+    if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
+        !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
+      return false;
+
+    QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
+    bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
+    unsigned SourceLen = SourceLHS.getVectorLength();
+    SmallVector<APValue, 4> ResultElements;
+    ResultElements.reserve(SourceLen);
+
+    APSInt RHS = SourceRHS.getInt();
+
+    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+      APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
+      ResultElements.push_back(
+          APValue(APSInt(LHS.rotl(RHS.urem(LHS.getBitWidth())), 
DestUnsigned)));
+    }
+
+    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+  }
+  case clang::X86::BI__builtin_ia32_prord128:
+  case clang::X86::BI__builtin_ia32_prord256:
+  case clang::X86::BI__builtin_ia32_prord512:
+  case clang::X86::BI__builtin_ia32_prorq128:
+  case clang::X86::BI__builtin_ia32_prorq256:
+  case clang::X86::BI__builtin_ia32_prorq512: {
+    APValue SourceLHS, SourceRHS;
+    if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
+        !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
+      return false;
+
+    QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
+    bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
+    unsigned SourceLen = SourceLHS.getVectorLength();
+    SmallVector<APValue, 4> ResultElements;
+    ResultElements.reserve(SourceLen);
+
+    APSInt RHS = SourceRHS.getInt();
+
+    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+      APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
+      ResultElements.push_back(
+          APValue(APSInt(LHS.rotr(RHS.urem(LHS.getBitWidth())), 
DestUnsigned)));
+    }
+
+    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+  }
   case Builtin::BI__builtin_elementwise_max:
   case Builtin::BI__builtin_elementwise_min: {
     APValue SourceLHS, SourceRHS;
diff --git a/clang/test/CodeGen/X86/avx512vl-builtins.c 
b/clang/test/CodeGen/X86/avx512vl-builtins.c
index 7043927185a3a..c9fa15e7fa571 100644
--- a/clang/test/CodeGen/X86/avx512vl-builtins.c
+++ b/clang/test/CodeGen/X86/avx512vl-builtins.c
@@ -5666,6 +5666,7 @@ __m128i test_mm_rol_epi32(__m128i __A) {
   // CHECK: @llvm.fshl.v4i32
   return _mm_rol_epi32(__A, 5); 
 }
+TEST_CONSTEXPR(match_v4si(_mm_rol_epi32(((__m128i)(__v4si){1, -2, 3, -4}), 5), 
32, -33, 96, -97));
 
 __m128i test_mm_mask_rol_epi32(__m128i __W, __mmask8 __U, __m128i __A) {
   // CHECK-LABEL: test_mm_mask_rol_epi32
diff --git a/clang/test/CodeGen/X86/xop-builtins.c 
b/clang/test/CodeGen/X86/xop-builtins.c
index cd403d5876fa3..718539bb0eef7 100644
--- a/clang/test/CodeGen/X86/xop-builtins.c
+++ b/clang/test/CodeGen/X86/xop-builtins.c
@@ -251,6 +251,7 @@ __m128i test_mm_roti_epi32(__m128i a) {
   // CHECK: call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %{{.*}}, <4 x i32> 
%{{.*}}, <4 x i32> splat (i32 226))
   return _mm_roti_epi32(a, -30);
 }
+TEST_CONSTEXPR(match_v4si(_mm_roti_epi32(((__m128i)(__v4si){1, -2, 3, -4}), 
5), 32, -33, 96, -97));
 
 __m128i test_mm_roti_epi64(__m128i a) {
   // CHECK-LABEL: test_mm_roti_epi64

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to