https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/156017
Allows us to handle OP(vector, scalar) cases such as x86 shift/rotate by immediate. I've included a little of the x86 vector rotate-by-immediate constexpr code to show how I reckon this will work. CC @Arghnews - if I've gotten this right, you should be able to add the shift-by-immediate builtin IDs to share the code from the appropriate shl/lshr/ashr shift-by-variable cases. >From 6c66244d28c53c9f74c3f159f985595cac7252fd Mon Sep 17 00:00:00 2001 From: Simon Pilgrim <[email protected]> Date: Fri, 29 Aug 2025 13:29:00 +0100 Subject: [PATCH] [Clang][bytecode] interp__builtin_elementwise_int_binop - allow RHS operand to be a scalar Allows us to handle OP(vector, scalar) cases such as x86 shift/rotate by immediate --- clang/include/clang/Basic/BuiltinsX86.td | 17 +++-- clang/lib/AST/ByteCode/InterpBuiltin.cpp | 74 ++++++++++++++++++---- clang/lib/AST/ExprConstant.cpp | 58 +++++++++++++++++ clang/test/CodeGen/X86/avx512vl-builtins.c | 1 + clang/test/CodeGen/X86/xop-builtins.c | 1 + 5 files changed, 133 insertions(+), 18 deletions(-) diff --git a/clang/include/clang/Basic/BuiltinsX86.td b/clang/include/clang/Basic/BuiltinsX86.td index 5874aee6c83fc..11c72749ca511 100644 --- a/clang/include/clang/Basic/BuiltinsX86.td +++ b/clang/include/clang/Basic/BuiltinsX86.td @@ -925,10 +925,6 @@ let Features = "xop", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in def vphsubwd : X86Builtin<"_Vector<4, int>(_Vector<8, short>)">; def vphsubdq : X86Builtin<"_Vector<2, long long int>(_Vector<4, int>)">; def vpperm : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Vector<16, char>)">; - def vprotbi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Constant char)">; - def vprotwi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Constant char)">; - def vprotdi : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant char)">; - def vprotqi : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant char)">; def vpshlb : 
X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>)">; def vpshlw : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">; def vpshld : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">; @@ -953,6 +949,13 @@ let Features = "xop", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in def vfrczpd : X86Builtin<"_Vector<2, double>(_Vector<2, double>)">; } +let Features = "xop", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in { + def vprotbi : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Constant char)">; + def vprotwi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Constant char)">; + def vprotdi : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant char)">; + def vprotqi : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant char)">; +} + let Features = "xop", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in { def vpermil2pd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<4, double>, _Vector<4, long long int>, _Constant char)">; def vpermil2ps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<8, float>, _Vector<8, int>, _Constant char)">; @@ -2036,21 +2039,21 @@ let Features = "avx512dq,evex512", Attributes = [NoThrow, Const, RequiredVectorW def reduceps512_mask : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Constant int, _Vector<16, float>, unsigned short, _Constant int)">; } -let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in { +let Features = "avx512f,evex512", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in { def prold512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Constant int)">; def prord512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Constant int)">; def prolq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant int)">; def prorq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant 
int)">; } -let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in { +let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in { def prold128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant int)">; def prord128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Constant int)">; def prolq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">; def prorq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">; } -let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in { +let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in { def prold256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Constant int)">; def prord256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Constant int)">; def prolq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">; diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp index eba0b25997699..c4f4e1c2daa03 100644 --- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp +++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp @@ -2565,21 +2565,48 @@ static bool interp__builtin_elementwise_int_binop( return true; } + const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>(); + assert(VT->getElementType()->isIntegralOrEnumerationType()); + PrimType ElemT = *S.getContext().classify(VT->getElementType()); + unsigned NumElems = VT->getNumElements(); + + // Vector + Scalar case. 
+ if (!Call->getArg(1)->getType()->isVectorType()) { + assert(Call->getArg(1)->getType()->isIntegralOrEnumerationType()); + + APSInt RHS = popToAPSInt( + S.Stk, *S.getContext().classify(Call->getArg(1)->getType())); + const Pointer &LHS = S.Stk.pop<Pointer>(); + const Pointer &Dst = S.Stk.peek<Pointer>(); + + for (unsigned I = 0; I != NumElems; ++I) { + APSInt Elem1; + INT_TYPE_SWITCH_NO_BOOL(ElemT, { + Elem1 = LHS.elem<T>(I).toAPSInt(); + }); + + APSInt Result = + APSInt(Fn(Elem1, RHS), + Call->getType()->isUnsignedIntegerOrEnumerationType()); + + INT_TYPE_SWITCH_NO_BOOL(ElemT, + { Dst.elem<T>(I) = static_cast<T>(Result); }); + } + Dst.initializeAllElements(); + return true; + } + // Vector case. assert(Call->getArg(0)->getType()->isVectorType() && Call->getArg(1)->getType()->isVectorType()); - const auto *VT = Call->getArg(0)->getType()->castAs<VectorType>(); assert(VT->getElementType() == Call->getArg(1)->getType()->castAs<VectorType>()->getElementType()); assert(VT->getNumElements() == Call->getArg(1)->getType()->castAs<VectorType>()->getNumElements()); - assert(VT->getElementType()->isIntegralOrEnumerationType()); const Pointer &RHS = S.Stk.pop<Pointer>(); const Pointer &LHS = S.Stk.pop<Pointer>(); const Pointer &Dst = S.Stk.peek<Pointer>(); - PrimType ElemT = *S.getContext().classify(VT->getElementType()); - unsigned NumElems = VT->getNumElements(); for (unsigned I = 0; I != NumElems; ++I) { APSInt Elem1; APSInt Elem2; @@ -2596,7 +2623,6 @@ static bool interp__builtin_elementwise_int_binop( { Dst.elem<T>(I) = static_cast<T>(Result); }); } Dst.initializeAllElements(); - return true; } @@ -3256,8 +3282,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case clang::X86::BI__builtin_ia32_psllv8si: return interp__builtin_elementwise_int_binop( S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) { - if (RHS.uge(RHS.getBitWidth())) { - return APInt::getZero(RHS.getBitWidth()); + if (RHS.uge(LHS.getBitWidth())) { + return 
APInt::getZero(LHS.getBitWidth()); } return LHS.shl(RHS.getZExtValue()); }); @@ -3266,8 +3292,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case clang::X86::BI__builtin_ia32_psrav8si: return interp__builtin_elementwise_int_binop( S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) { - if (RHS.uge(RHS.getBitWidth())) { - return LHS.ashr(RHS.getBitWidth() - 1); + if (RHS.uge(LHS.getBitWidth())) { + return LHS.ashr(LHS.getBitWidth() - 1); } return LHS.ashr(RHS.getZExtValue()); }); @@ -3278,12 +3304,38 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case clang::X86::BI__builtin_ia32_psrlv8si: return interp__builtin_elementwise_int_binop( S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) { - if (RHS.uge(RHS.getBitWidth())) { - return APInt::getZero(RHS.getBitWidth()); + if (RHS.uge(LHS.getBitWidth())) { + return APInt::getZero(LHS.getBitWidth()); } return LHS.lshr(RHS.getZExtValue()); }); + case clang::X86::BI__builtin_ia32_vprotbi: + case clang::X86::BI__builtin_ia32_vprotdi: + case clang::X86::BI__builtin_ia32_vprotqi: + case clang::X86::BI__builtin_ia32_vprotwi: + case clang::X86::BI__builtin_ia32_prold128: + case clang::X86::BI__builtin_ia32_prold256: + case clang::X86::BI__builtin_ia32_prold512: + case clang::X86::BI__builtin_ia32_prolq128: + case clang::X86::BI__builtin_ia32_prolq256: + case clang::X86::BI__builtin_ia32_prolq512: + return interp__builtin_elementwise_int_binop( + S, OpPC, Call, BuiltinID, [](const APSInt &LHS, const APSInt &RHS) { + return LHS.rotl(RHS.urem(LHS.getBitWidth())); + }); + + case clang::X86::BI__builtin_ia32_prord128: + case clang::X86::BI__builtin_ia32_prord256: + case clang::X86::BI__builtin_ia32_prord512: + case clang::X86::BI__builtin_ia32_prorq128: + case clang::X86::BI__builtin_ia32_prorq256: + case clang::X86::BI__builtin_ia32_prorq512: + return interp__builtin_elementwise_int_binop( + S, OpPC, Call, BuiltinID, [](const APSInt &LHS, 
const APSInt &RHS) { + return LHS.rotr(RHS.urem(LHS.getBitWidth())); + }); + case Builtin::BI__builtin_elementwise_max: case Builtin::BI__builtin_elementwise_min: return interp__builtin_elementwise_maxmin(S, OpPC, Call, BuiltinID); diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index e7dc1d1ca6c27..b318540bba1e3 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -11746,6 +11746,64 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) { return Success(APValue(ResultElements.data(), ResultElements.size()), E); } + case clang::X86::BI__builtin_ia32_vprotbi: + case clang::X86::BI__builtin_ia32_vprotdi: + case clang::X86::BI__builtin_ia32_vprotqi: + case clang::X86::BI__builtin_ia32_vprotwi: + case clang::X86::BI__builtin_ia32_prold128: + case clang::X86::BI__builtin_ia32_prold256: + case clang::X86::BI__builtin_ia32_prold512: + case clang::X86::BI__builtin_ia32_prolq128: + case clang::X86::BI__builtin_ia32_prolq256: + case clang::X86::BI__builtin_ia32_prolq512: { + APValue SourceLHS, SourceRHS; + if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) || + !EvaluateAsRValue(Info, E->getArg(1), SourceRHS)) + return false; + + QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType(); + bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType(); + unsigned SourceLen = SourceLHS.getVectorLength(); + SmallVector<APValue, 4> ResultElements; + ResultElements.reserve(SourceLen); + + APSInt RHS = SourceRHS.getInt(); + + for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) { + APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt(); + ResultElements.push_back( + APValue(APSInt(LHS.rotl(RHS.urem(LHS.getBitWidth())), DestUnsigned))); + } + + return Success(APValue(ResultElements.data(), ResultElements.size()), E); + } + case clang::X86::BI__builtin_ia32_prord128: + case clang::X86::BI__builtin_ia32_prord256: + case clang::X86::BI__builtin_ia32_prord512: + case 
clang::X86::BI__builtin_ia32_prorq128: + case clang::X86::BI__builtin_ia32_prorq256: + case clang::X86::BI__builtin_ia32_prorq512: { + APValue SourceLHS, SourceRHS; + if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) || + !EvaluateAsRValue(Info, E->getArg(1), SourceRHS)) + return false; + + QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType(); + bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType(); + unsigned SourceLen = SourceLHS.getVectorLength(); + SmallVector<APValue, 4> ResultElements; + ResultElements.reserve(SourceLen); + + APSInt RHS = SourceRHS.getInt(); + + for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) { + APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt(); + ResultElements.push_back( + APValue(APSInt(LHS.rotr(RHS.urem(LHS.getBitWidth())), DestUnsigned))); + } + + return Success(APValue(ResultElements.data(), ResultElements.size()), E); + } case Builtin::BI__builtin_elementwise_max: case Builtin::BI__builtin_elementwise_min: { APValue SourceLHS, SourceRHS; diff --git a/clang/test/CodeGen/X86/avx512vl-builtins.c b/clang/test/CodeGen/X86/avx512vl-builtins.c index 7043927185a3a..c9fa15e7fa571 100644 --- a/clang/test/CodeGen/X86/avx512vl-builtins.c +++ b/clang/test/CodeGen/X86/avx512vl-builtins.c @@ -5666,6 +5666,7 @@ __m128i test_mm_rol_epi32(__m128i __A) { // CHECK: @llvm.fshl.v4i32 return _mm_rol_epi32(__A, 5); } +TEST_CONSTEXPR(match_v4si(_mm_rol_epi32(((__m128i)(__v4si){1, -2, 3, -4}), 5), 32, -33, 96, -97)); __m128i test_mm_mask_rol_epi32(__m128i __W, __mmask8 __U, __m128i __A) { // CHECK-LABEL: test_mm_mask_rol_epi32 diff --git a/clang/test/CodeGen/X86/xop-builtins.c b/clang/test/CodeGen/X86/xop-builtins.c index cd403d5876fa3..718539bb0eef7 100644 --- a/clang/test/CodeGen/X86/xop-builtins.c +++ b/clang/test/CodeGen/X86/xop-builtins.c @@ -251,6 +251,7 @@ __m128i test_mm_roti_epi32(__m128i a) { // CHECK: call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> splat (i32 226)) 
return _mm_roti_epi32(a, -30); } +TEST_CONSTEXPR(match_v4si(_mm_roti_epi32(((__m128i)(__v4si){1, -2, 3, -4}), 5), 32, -33, 96, -97)); __m128i test_mm_roti_epi64(__m128i a) { // CHECK-LABEL: test_mm_roti_epi64 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
