[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
addmisol wrote: thanks @arsenm, also please assign me any issues that need to be solved for AMDGPU.. https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
addmisol wrote: Thanks @arsenm, can you please also merge this? I currently don't have write access. https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,148 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Check if type is <4 x i8>.
+static bool isV4I8(Type *Ty) {
+ FixedVectorType *VTy = dyn_cast(Ty);
+ return VTy && VTy->getNumElements() == 4 &&
+ VTy->getElementType()->isIntegerTy(8);
+}
+
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches and signedness matches IsSigned.
+/// Sets A, B to the <4 x i8> sources.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool IsSigned) {
+ Value *Src0, *Src1;
+ if (!match(MulOp, m_Mul(m_Value(Src0), m_Value(Src1
+return false;
+
+ // Check that result type is <4 x i32>
+ FixedVectorType *MulTy = dyn_cast(MulOp->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ // Match zext or sext based on IsSigned
+ Value *ExtSrc0, *ExtSrc1;
+ if (IsSigned) {
+if (!match(Src0, m_SExt(m_Value(ExtSrc0))) || !isV4I8(ExtSrc0->getType()))
+ return false;
+if (!match(Src1, m_SExt(m_Value(ExtSrc1))) || !isV4I8(ExtSrc1->getType()))
+ return false;
+ } else {
+if (!match(Src0, m_ZExt(m_Value(ExtSrc0))) || !isV4I8(ExtSrc0->getType()))
+ return false;
+if (!match(Src1, m_ZExt(m_Value(ExtSrc1))) || !isV4I8(ExtSrc1->getType()))
+ return false;
+ }
+
+ A = ExtSrc0;
+ B = ExtSrc1;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+
+ // Try unsigned first, then signed
+ bool IsSigned = false;
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, /*IsSigned=*/false)) {
+if (!matchDot4Pattern(I.getArgOperand(0), A, B, /*IsSigned=*/true))
+ return false;
+IsSigned = true;
+ }
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
addmisol wrote:
I tested this ir on latest main:
```
define i32 @test(i32 %a, i32 %b, i32 %acc) {
%dot = call i32 @llvm.amdgcn.udot4(i32 %a, i32 %b, i32 0, i1 false)
%result = call i32 @llvm.uadd.sat.i32(i32 %dot, i32 %acc)
ret i32 %result
}
```
It generates 2 instructions:
```
v_dot4_u32_u8 v0, v0, v1, 0
v_add_u32_e64 v0, v0, v2 clamp
```
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -1129,10 +1131,10 @@ Value *AMDGPUCodeGenPrepareImpl::expandDivRem24Impl( : Builder.CreateFPToUI(FQ, I32Ty); // fr = fabs(fr); - FR = Builder.CreateFAbs(FR, FQ); + FR = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FR, FQ); arsenm wrote: Unrelated change; this regresses the use of CreateFAbs. https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol edited https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -728,13 +728,49 @@ defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f def : UDot2Pat; def : SDot2Pat; +// Saturating unsigned dot2 pattern: uaddsat(a[0]*b[0] + a[1]*b[1], c) addmisol wrote: the UDot2SatPat and the SDot2SatPat Tablegen patterns are the only way dot2 saturating cases get matched. https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -728,13 +728,49 @@ defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f def : UDot2Pat; def : SDot2Pat; +// Saturating unsigned dot2 pattern: uaddsat(a[0]*b[0] + a[1]*b[1], c) addmisol wrote: No, CGP is not handling the dot2 at all. CGP only handles dot4 patterns (with <4 x i8> + vector.reduce.add). https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol deleted https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol edited https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -728,13 +728,49 @@ defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f def : UDot2Pat; def : SDot2Pat; +// Saturating unsigned dot2 pattern: uaddsat(a[0]*b[0] + a[1]*b[1], c) addmisol wrote: No, CGP is not handling dot2 at all. CGP only handles dot4 patterns (with <4 x i8> + vector.reduce.add). https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol deleted https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
addmisol wrote: So far, with this PR's changes, these additional dot patterns are recognized compared to the main branch: for dot2 - udot2_sat - sdot2_sat for dot4 - test_udot4_sat - test_sdot4_sat - test_udot4_unsat - test_sdot4_unsat https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -728,13 +728,49 @@ defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f def : UDot2Pat; def : SDot2Pat; +// Saturating unsigned dot2 pattern: uaddsat(a[0]*b[0] + a[1]*b[1], c) addmisol wrote: Let me check this again, getting a bit confused🙂 https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,153 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Check if type is <4 x i8>.
+static bool isV4I8(Type *Ty) {
+ FixedVectorType *VTy = dyn_cast(Ty);
+ return VTy && VTy->getNumElements() == 4 &&
+ VTy->getElementType()->isIntegerTy(8);
+}
+
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ Value *Src0, *Src1;
+ if (!match(MulOp, m_Mul(m_Value(Src0), m_Value(Src1
+return false;
+
+ // Check that result type is <4 x i32>
+ FixedVectorType *MulTy = dyn_cast(MulOp->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ Value *ExtSrc0, *ExtSrc1;
+ if (!match(Src0, m_ZExtOrSExt(m_Value(ExtSrc0))) ||
!isV4I8(ExtSrc0->getType()))
+return false;
+ if (!match(Src1, m_ZExtOrSExt(m_Value(ExtSrc1))) ||
!isV4I8(ExtSrc1->getType()))
+return false;
+
+ // Both operands must have the same signedness
+ bool Signed0 = isa(Src0);
+ bool Signed1 = isa(Src1);
+ if (Signed0 != Signed1)
+return false;
+
+ A = ExtSrc0;
+ B = ExtSrc1;
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+ bool IsSigned = false;
+
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
+
+ Intrinsic::ID DotIID =
+ IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+ Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Acc, Clamp});
+ Dot->takeName(&I);
+
+ I.replaceAllUsesWith(Dot);
+ DeadVals.push_back(&I);
+
+ return true;
+}
+
+/// Try to convert uadd.sat/sadd.sat(vector.reduce.add(mul(...)), c) to a
+/// saturating dot4 intrinsic. This combine starts at the root (saturating add)
+/// and looks at its operands.
+bool AMDGPUCodeGenPrepareImpl::visitSaturatingAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Intrinsic::ID IID = I.getIntrinsicID();
+ bool IsSigned = (IID == Intrinsic::sadd_sat);
+
+ // Look for vector.reduce.add as one of the operands
+ Value *ReduceOp = nullptr;
+ Value *Accum = nullptr;
+
+ for (int Swap = 0; Swap < 2; ++Swap) {
addmisol wrote:
sure👍
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,153 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Check if type is <4 x i8>.
+static bool isV4I8(Type *Ty) {
+ FixedVectorType *VTy = dyn_cast(Ty);
+ return VTy && VTy->getNumElements() == 4 &&
+ VTy->getElementType()->isIntegerTy(8);
+}
+
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ Value *Src0, *Src1;
+ if (!match(MulOp, m_Mul(m_Value(Src0), m_Value(Src1
+return false;
+
+ // Check that result type is <4 x i32>
+ FixedVectorType *MulTy = dyn_cast(MulOp->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ Value *ExtSrc0, *ExtSrc1;
+ if (!match(Src0, m_ZExtOrSExt(m_Value(ExtSrc0))) ||
!isV4I8(ExtSrc0->getType()))
+return false;
+ if (!match(Src1, m_ZExtOrSExt(m_Value(ExtSrc1))) ||
!isV4I8(ExtSrc1->getType()))
+return false;
+
+ // Both operands must have the same signedness
+ bool Signed0 = isa(Src0);
+ bool Signed1 = isa(Src1);
+ if (Signed0 != Signed1)
+return false;
+
+ A = ExtSrc0;
+ B = ExtSrc1;
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+ bool IsSigned = false;
+
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
+
+ Intrinsic::ID DotIID =
+ IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+ Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Acc, Clamp});
+ Dot->takeName(&I);
+
+ I.replaceAllUsesWith(Dot);
+ DeadVals.push_back(&I);
+
+ return true;
+}
+
+/// Try to convert uadd.sat/sadd.sat(vector.reduce.add(mul(...)), c) to a
+/// saturating dot4 intrinsic. This combine starts at the root (saturating add)
+/// and looks at its operands.
+bool AMDGPUCodeGenPrepareImpl::visitSaturatingAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Intrinsic::ID IID = I.getIntrinsicID();
+ bool IsSigned = (IID == Intrinsic::sadd_sat);
+
+ // Look for vector.reduce.add as one of the operands
+ Value *ReduceOp = nullptr;
+ Value *Accum = nullptr;
+
+ for (int Swap = 0; Swap < 2; ++Swap) {
arsenm wrote:
You can use m_c_Intrinsic to handle the commutative cases here
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,162 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Check if type is <4 x i8>
+ auto IsV4I8 = [](Type *Ty) -> bool {
+auto *VTy = dyn_cast(Ty);
+return VTy && VTy->getNumElements() == 4 &&
+ VTy->getElementType()->isIntegerTy(8);
+ };
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto MatchExtend = [&IsV4I8](Value *V, Value *&Src, bool &Signed) -> bool {
+if (match(V, m_ZExt(m_Value(Src))) && IsV4I8(Src->getType())) {
arsenm wrote:
Can you use m_ZExtOrSExt instead?
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,162 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Check if type is <4 x i8>
+ auto IsV4I8 = [](Type *Ty) -> bool {
+auto *VTy = dyn_cast(Ty);
+return VTy && VTy->getNumElements() == 4 &&
+ VTy->getElementType()->isIntegerTy(8);
+ };
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto MatchExtend = [&IsV4I8](Value *V, Value *&Src, bool &Signed) -> bool {
+if (match(V, m_ZExt(m_Value(Src))) && IsV4I8(Src->getType())) {
+ Signed = false;
+ return true;
+}
+if (match(V, m_SExt(m_Value(Src))) && IsV4I8(Src->getType())) {
+ Signed = true;
+ return true;
+}
+return false;
+ };
+
+ bool Signed0 = false, Signed1 = false;
+ if (!MatchExtend(Src0, A, Signed0) || !MatchExtend(Src1, B, Signed1))
+return false;
+
+ // Both operands must have the same signedness
+ if (Signed0 != Signed1)
+return false;
+
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+ bool IsSigned = false;
+
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
arsenm wrote:
Do we handle folding clamping patterns into the result somewhere else?
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2405,6 +2412,162 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
arsenm wrote:
Can move this into `match(MulOp, m_Mul(m_Value(Src0), m_Value(Src1))`
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -16942,6 +16959,85 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
return SDValue();
}
+// Try to fold saturating add with dot product pattern into dot instruction
+// with clamp. Matches patterns like:
+// uaddsat(a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], c) -> v_dot4 clamp
+SDValue SITargetLowering::performSatAddCombine(SDNode *N,
+ DAGCombinerInfo &DCI) const {
+ SelectionDAG &DAG = DCI.DAG;
+ EVT VT = N->getValueType(0);
+ SDLoc SL(N);
+
+ // Only handle i32 saturating adds
+ if (VT != MVT::i32)
+return SDValue();
+
+ bool IsSigned = N->getOpcode() == ISD::SADDSAT;
+
+ // Check if we have dot instructions
+ if (!Subtarget->hasDot7Insts() ||
+ (!Subtarget->hasDot1Insts() && !Subtarget->hasDot8Insts()))
+return SDValue();
+
+ // First, check if one operand is already a dot intrinsic without clamp.
+ // If performAddCombine already created a dot instruction with clamp=0,
+ // we can fold the saturating add by regenerating with clamp=1.
+ SDValue DotOp = N->getOperand(0);
+ SDValue Accum = N->getOperand(1);
+
+ // Try both operand orders
+ for (int Swap = 0; Swap < 2; ++Swap) {
+if (DotOp.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+ auto *IIDNode = dyn_cast(DotOp.getOperand(0));
+ if (!IIDNode) {
+std::swap(DotOp, Accum);
+continue;
+ }
+ unsigned IID = IIDNode->getZExtValue();
+ // Check for udot4/sdot4/udot2/sdot2 intrinsics
+ if ((IID == Intrinsic::amdgcn_udot4 && !IsSigned) ||
+ (IID == Intrinsic::amdgcn_sdot4 && IsSigned) ||
+ (IID == Intrinsic::amdgcn_udot2 && !IsSigned) ||
+ (IID == Intrinsic::amdgcn_sdot2 && IsSigned)) {
+// DotOp layout: [IID, Src0, Src1, Src2/Accum, Clamp]
+// Check if clamp is 0 and accumulator is 0
+SDValue OldAccum = DotOp.getOperand(3);
+SDValue OldClamp = DotOp.getOperand(4);
+
+// Check if old clamp is 0 (otherwise already saturating)
+auto *ClampConst = dyn_cast(OldClamp);
+if (!ClampConst || ClampConst->getZExtValue() != 0) {
+ std::swap(DotOp, Accum);
+ continue;
+}
+
+// Check if old accumulator is 0 (the pattern is dot(..., 0) + accum)
+auto *AccumConst = dyn_cast(OldAccum);
+if (!AccumConst || AccumConst->getZExtValue() != 0) {
+ std::swap(DotOp, Accum);
+ continue;
+}
+
+// Regenerate the dot with clamp=1 and the new accumulator
+SDValue NewIID = DAG.getTargetConstant(IID, SL, MVT::i64);
+SDValue Src0 = DotOp.getOperand(1);
+SDValue Src1 = DotOp.getOperand(2);
+
+auto NewDot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32,
addmisol wrote:
I still have to remove some more uses of `auto`; I will do it tomorrow.
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol edited https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -16946,6 +16946,73 @@ SDValue SITargetLowering::performAddCombine(SDNode *N, return SDValue(); } +// Try to fold saturating add with dot product intrinsic into dot instruction +// with clamp. Matches patterns like: +// uaddsat(dot4(..., 0), c) -> dot4(..., c) clamp +// uaddsat(dot2(..., 0), c) -> dot2(..., c) clamp +SDValue SITargetLowering::performSatAddCombine(SDNode *N, addmisol wrote: Removed performSatAddCombine. The saturation patterns are already covered by: - visitSaturatingAdd in AMDGPUCodeGenPrepare for dot4 (handles uaddsat(reduce.add(mul(...)), c)) - TableGen patterns UDot2SatPat/SDot2SatPat in VOP3PInstructions.td for dot2 https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol deleted https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol edited https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -16946,6 +16946,73 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
return SDValue();
}
+// Try to fold saturating add with dot product intrinsic into dot instruction
+// with clamp. Matches patterns like:
+// uaddsat(dot4(..., 0), c) -> dot4(..., c) clamp
+// uaddsat(dot2(..., 0), c) -> dot2(..., c) clamp
+SDValue SITargetLowering::performSatAddCombine(SDNode *N,
addmisol wrote:
No, performSatAddCombine is currently not being called because
ISD::UADDSAT/ISD::SADDSAT are not registered with setTargetDAGCombine. Without
that registration, the DAG combiner never visits these nodes.
Should I:
1. Remove performSatAddCombine since the main paths are already covered, or
2. Add setTargetDAGCombine({..., ISD::UADDSAT, ISD::SADDSAT, ..}) to make it
functional?
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2366,6 +2373,164 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
arsenm wrote:
This could probably be a bit more tolerant of smaller bitwidths. Could use
computeNumSignBits, but that is best left for later.
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -16946,6 +16946,73 @@ SDValue SITargetLowering::performAddCombine(SDNode *N, return SDValue(); } +// Try to fold saturating add with dot product intrinsic into dot instruction +// with clamp. Matches patterns like: +// uaddsat(dot4(..., 0), c) -> dot4(..., c) clamp +// uaddsat(dot2(..., 0), c) -> dot2(..., c) clamp +SDValue SITargetLowering::performSatAddCombine(SDNode *N, arsenm wrote: Is this version still tested? https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2366,6 +2373,164 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = ZExt->getOperand(0);
+Signed = false;
+return true;
+ }
+} else if (auto *SExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(SExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
arsenm wrote:
Can use common isV4I8 function between the two cases
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2366,6 +2373,164 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = ZExt->getOperand(0);
+Signed = false;
+return true;
+ }
+} else if (auto *SExt = dyn_cast(V)) {
arsenm wrote:
Could shrink this using match()
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2366,6 +2373,164 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = ZExt->getOperand(0);
+Signed = false;
+return true;
+ }
+} else if (auto *SExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(SExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = SExt->getOperand(0);
+Signed = true;
+return true;
+ }
+}
+return false;
+ };
+
+ bool Signed0 = false, Signed1 = false;
+ if (!matchExtend(Src0, A, Signed0) || !matchExtend(Src1, B, Signed1))
+return false;
+
+ // Both operands must have the same signedness
+ if (Signed0 != Signed1)
+return false;
+
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+ bool IsSigned = false;
+
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
+
+ Intrinsic::ID DotIID =
+ IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+ Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Acc, Clamp},
+ nullptr, I.getName());
arsenm wrote:
Use takeName after instead of passing in the name
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2366,6 +2373,164 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = ZExt->getOperand(0);
+Signed = false;
+return true;
+ }
+} else if (auto *SExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(SExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = SExt->getOperand(0);
+Signed = true;
+return true;
+ }
+}
+return false;
+ };
+
+ bool Signed0 = false, Signed1 = false;
+ if (!matchExtend(Src0, A, Signed0) || !matchExtend(Src1, B, Signed1))
+return false;
+
+ // Both operands must have the same signedness
+ if (Signed0 != Signed1)
+return false;
+
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Value *A = nullptr, *B = nullptr;
+ bool IsSigned = false;
+
+ if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Non-saturating case: accumulator is 0, clamp is false
+ Value *Acc = ConstantInt::get(I32Ty, 0);
+ Value *Clamp = ConstantInt::getFalse(Ctx);
+
+ Intrinsic::ID DotIID =
+ IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+ Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Acc, Clamp},
+ nullptr, I.getName());
+
+ I.replaceAllUsesWith(Dot);
+ DeadVals.push_back(&I);
+
+ return true;
+}
+
+/// Try to convert uadd.sat/sadd.sat(vector.reduce.add(mul(...)), c) to a
+/// saturating dot4 intrinsic. This combine starts at the root (saturating add)
+/// and looks at its operands.
+bool AMDGPUCodeGenPrepareImpl::visitSaturatingAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ Intrinsic::ID IID = I.getIntrinsicID();
+ bool IsSigned = (IID == Intrinsic::sadd_sat);
+
+ // Look for vector.reduce.add as one of the operands
+ Value *ReduceOp = nullptr;
+ Value *Accum = nullptr;
+
+ for (int Swap = 0; Swap < 2; ++Swap) {
+Value *Op0 = I.getArgOperand(Swap);
+Value *Op1 = I.getArgOperand(1 - Swap);
+
+if (auto *ReduceInst = dyn_cast(Op0)) {
+ if (ReduceInst->getIntrinsicID() == Intrinsic::vector_reduce_add) {
+ReduceOp = Op0;
+Accum = Op1;
+break;
+ }
+}
+ }
+
+ if (!ReduceOp)
+return false;
+
+ auto *ReduceInst = cast(ReduceOp);
+
+ Value *A = nullptr, *B = nullptr;
+ bool PatternSigned = false;
+
+ if (!matchDot4Pattern(ReduceInst->getArgOperand(0), A, B, PatternSigned))
+return false;
+
+ // Signedness of the pattern must match the saturating add type
+ if (PatternSigned != IsSigned)
+return false;
+
+ LLVMContext &Ctx = I.getContext();
+ Type *I32Ty = Type::getInt32Ty(Ctx);
+ IRBuilder<> Builder(&I);
+
+ // Bitcast <4 x i8> to i32
+ Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+ Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+ // Saturating case: use the accumulator and set clamp to true
+ Value *Clamp = ConstantInt::getTrue(Ctx);
+
+ Intrinsic::ID DotIID =
+ IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+ Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Accum, Clamp},
+ nullptr, I.getName());
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
addmisol wrote: @arsenm, could you please tell me if I need to update anything else? https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -2336,6 +2343,177 @@ bool
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
return tryReplaceWithWorkitemId(I, Wave);
}
+/// Helper to match the dot4 pattern: mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>) Returns true if pattern matches, sets A, B to the <4 x i8> sources and
+/// IsSigned based on whether sext was used.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+ bool &IsSigned) {
+ auto *Mul = dyn_cast(MulOp);
+ if (!Mul || Mul->getOpcode() != Instruction::Mul)
+return false;
+
+ // Check that result type is <4 x i32>
+ auto *MulTy = dyn_cast(Mul->getType());
+ if (!MulTy || MulTy->getNumElements() != 4 ||
+ !MulTy->getElementType()->isIntegerTy(32))
+return false;
+
+ Value *Src0 = Mul->getOperand(0);
+ Value *Src1 = Mul->getOperand(1);
+
+ // Match zext <4 x i8> or sext <4 x i8>
+ auto matchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+if (auto *ZExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(ZExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = ZExt->getOperand(0);
+Signed = false;
+return true;
+ }
+} else if (auto *SExt = dyn_cast(V)) {
+ auto *SrcTy = dyn_cast(SExt->getSrcTy());
+ if (SrcTy && SrcTy->getNumElements() == 4 &&
+ SrcTy->getElementType()->isIntegerTy(8)) {
+Src = SExt->getOperand(0);
+Signed = true;
+return true;
+ }
+}
+return false;
+ };
+
+ bool Signed0 = false, Signed1 = false;
+ if (!matchExtend(Src0, A, Signed0) || !matchExtend(Src1, B, Signed1))
+return false;
+
+ // Both operands must have the same signedness
+ if (Signed0 != Signed1)
+return false;
+
+ IsSigned = Signed0;
+ return true;
+}
+
+/// Try to convert vector.reduce.add(mul(zext/sext <4 x i8>, zext/sext <4 x
+/// i8>)) to a dot4 intrinsic call (non-saturating case only). The saturating
+/// case is handled by visitSaturatingAdd which starts at the root.
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+ // Check if we have dot4 instructions available
+ if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+return false;
+
+ // Skip if this reduce is used by a saturating add - that case will be
+ // handled by visitSaturatingAdd starting from the root instruction.
+ if (I.hasOneUse()) {
+if (auto *User = dyn_cast(*I.user_begin())) {
addmisol wrote:
Yes, this check was redundant. The pass iterates bottom-up, so
visitSaturatingAdd runs first at the root
instruction before visitVectorReduceAdd would process its operand.
Additionally, even if the non-saturating dot4 were generated first,
performSatAddCombine in the DAG combiner would fold uaddsat(dot4(..., 0), c) to
dot4(..., c, clamp=1) anyway.
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -731,13 +731,49 @@ defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f def : UDot2Pat; def : SDot2Pat; +// Saturating unsigned dot2 pattern: uaddsat(a[0]*b[0] + a[1]*b[1], c) +class UDot2SatPat : GCNPat < addmisol wrote: Yes, these patterns are still reachable. They match a different input form than performSatAddCombine: - performSatAddCombine: Matches uaddsat(INTRINSIC_WO_CHAIN(amdgcn_udot2, ..., 0), accum) — patterns where the dot intrinsic was already formed at IR level from <2 x i16> vectors. - UDot2SatPat/SDot2SatPat: Matches scalar decomposed patterns like uaddsat(add(mul_u24(srl $src0, 16), ..), ..) — these come from scalar i32 code with packed i16 values extracted via shifts/masks, which bypasses the IR-level intrinsic formation. Added scalar_udot2_sat and scalar_sdot2_sat tests to verify these patterns are exercised. https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -0,0 +1,687 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=amdgcn -mcpu=gfx906 < %s | FileCheck
--check-prefixes=GFX9-DL %s
+; RUN: llc -mtriple=amdgcn -mcpu=gfx1011 < %s | FileCheck
--check-prefixes=GFX10-DL %s
+; RUN: llc -mtriple=amdgcn -mcpu=gfx950 < %s | FileCheck
--check-prefixes=GFX950 %s
+
+; Test dot2 and dot4 patterns with saturating add (clamp) and without
+
+;--
+; DOT2 SATURATING TESTS
+;--
+
+; Unsigned dot2 with saturation: uaddsat(a[0]*b[0] + a[1]*b[1], c)
+define i32 @udot2_sat(<2 x i16> %a, <2 x i16> %b, i32 %c) {
+; GFX9-DL-LABEL: udot2_sat:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9-DL-NEXT:v_dot2_u32_u16 v0, v1, v0, v2 clamp
+; GFX9-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX10-DL-LABEL: udot2_sat:
+; GFX10-DL: ; %bb.0: ; %entry
+; GFX10-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX10-DL-NEXT:v_dot2_u32_u16 v0, v1, v0, v2 clamp
+; GFX10-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX950-LABEL: udot2_sat:
+; GFX950: ; %bb.0: ; %entry
+; GFX950-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:v_dot2_u32_u16 v0, v1, v0, v2 clamp
+; GFX950-NEXT:s_setpc_b64 s[30:31]
+entry:
+ %conv.i = zext <2 x i16> %a to <2 x i32>
+ %conv6.i = zext <2 x i16> %b to <2 x i32>
+ %mul.i = mul <2 x i32> %conv6.i, %conv.i
+ %0 = extractelement <2 x i32> %mul.i, i64 0
+ %1 = extractelement <2 x i32> %mul.i, i64 1
+ %add.i = add i32 %0, %1
+ %cond.i.i = tail call i32 @llvm.uadd.sat.i32(i32 %add.i, i32 %c)
+ ret i32 %cond.i.i
+}
+
+; Signed dot2 with saturation: saddsat(a[0]*b[0] + a[1]*b[1], c)
+define i32 @sdot2_sat(<2 x i16> %a, <2 x i16> %b, i32 %c) {
+; GFX9-DL-LABEL: sdot2_sat:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9-DL-NEXT:v_dot2_i32_i16 v0, v1, v0, v2 clamp
+; GFX9-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX10-DL-LABEL: sdot2_sat:
+; GFX10-DL: ; %bb.0: ; %entry
+; GFX10-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX10-DL-NEXT:v_dot2_i32_i16 v0, v1, v0, v2 clamp
+; GFX10-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX950-LABEL: sdot2_sat:
+; GFX950: ; %bb.0: ; %entry
+; GFX950-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:v_dot2_i32_i16 v0, v1, v0, v2 clamp
+; GFX950-NEXT:s_setpc_b64 s[30:31]
+entry:
+ %conv.i = sext <2 x i16> %a to <2 x i32>
+ %conv6.i = sext <2 x i16> %b to <2 x i32>
+ %mul.i = mul <2 x i32> %conv6.i, %conv.i
+ %0 = extractelement <2 x i32> %mul.i, i64 0
+ %1 = extractelement <2 x i32> %mul.i, i64 1
+ %add.i = add i32 %0, %1
+ %cond1.i.i = tail call i32 @llvm.sadd.sat.i32(i32 %add.i, i32 %c)
+ ret i32 %cond1.i.i
+}
+
+;--
+; DOT2 NON-SATURATING TESTS
+;--
+
+; Unsigned dot2 without saturation
+define i32 @udot2_unsat(<2 x i16> %a, <2 x i16> %b, i32 %c) {
+; GFX9-DL-LABEL: udot2_unsat:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9-DL-NEXT:v_dot2_u32_u16 v0, v1, v0, v2
+; GFX9-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX10-DL-LABEL: udot2_unsat:
+; GFX10-DL: ; %bb.0: ; %entry
+; GFX10-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX10-DL-NEXT:v_dot2_u32_u16 v0, v1, v0, v2
+; GFX10-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX950-LABEL: udot2_unsat:
+; GFX950: ; %bb.0: ; %entry
+; GFX950-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:v_dot2_u32_u16 v0, v1, v0, v2
+; GFX950-NEXT:s_setpc_b64 s[30:31]
+entry:
+ %conv.i = zext <2 x i16> %a to <2 x i32>
+ %conv6.i = zext <2 x i16> %b to <2 x i32>
+ %mul.i = mul <2 x i32> %conv6.i, %conv.i
+ %0 = extractelement <2 x i32> %mul.i, i64 0
+ %1 = extractelement <2 x i32> %mul.i, i64 1
+ %add.i = add i32 %1, %c
+ %add8.i = add i32 %add.i, %0
+ ret i32 %add8.i
+}
+
+; Signed dot2 without saturation
+define i32 @sdot2_unsat(<2 x i16> %a, <2 x i16> %b, i32 %c) {
+; GFX9-DL-LABEL: sdot2_unsat:
+; GFX9-DL: ; %bb.0: ; %entry
+; GFX9-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX9-DL-NEXT:v_dot2_i32_i16 v0, v1, v0, v2
+; GFX9-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX10-DL-LABEL: sdot2_unsat:
+; GFX10-DL: ; %bb.0: ; %entry
+; GFX10-DL-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX10-DL-NEXT:v_dot2_i32_i16 v0, v1, v0, v2
+; GFX10-DL-NEXT:s_setpc_b64 s[30:31]
+;
+; GFX950-LABEL: sdot2_unsat:
+; GFX950: ; %bb.0: ; %entry
+; GFX950-NEXT:s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX950-NEXT:v_dot2_i32_i16 v0, v1, v0, v2
+; GFX950-NEXT:s_setpc_b64 s[30:31]
+entry:
+ %conv.
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
@@ -5661,13 +5661,6 @@ static unsigned getDPPOpcForWaveReduction(unsigned Opc, return AMDGPU::V_OR_B32_dpp; case AMDGPU::S_XOR_B32: return AMDGPU::V_XOR_B32_dpp; - case AMDGPU::V_ADD_F32_e64: addmisol wrote: yeah sure, need a rebase https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
https://github.com/addmisol ready_for_review https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
addmisol wrote: For dot4 with <4 x i8> args: the AMDGPU ABI unpacks <4 x i8> vectors into 4 separate i32 registers when they are passed as function arguments. By the time the DAG combiner sees the code, it looks like: v0 = byte0 (as i32), v1 = byte1 (as i32), v2 = byte2 (as i32), v3 = byte3 (as i32). The packed byte pattern needed for v_dot4_u32_u8 is lost. However, this will work: when bytes are loaded from memory as a packed i32, dot4 works correctly: "v_dot4_u32_u8 v1, v1, v2, s0" (non-saturating) and "v_dot4_u32_u8 v1, v1, v2, s0 clamp" (saturating). So: - All dot2 patterns are fixed - dot4 patterns work when bytes come from memory (packed) - dot4 with <4 x i8> function arguments is an ABI limitation https://github.com/llvm/llvm-project/pull/187945 ___ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
[clang] [llvm] [AMDGPU] Add dot product patterns with saturating add (clamp) (PR #187945)
github-actions[bot] wrote:
:warning: C/C++ code formatter, clang-format found issues in your code.
:warning:
You can test this locally with the following command:
```bash
git-clang-format --diff origin/main HEAD --extensions h,cpp --
llvm/lib/Target/AMDGPU/SIISelLowering.cpp
llvm/lib/Target/AMDGPU/SIISelLowering.h --diff_from_common_commit
```
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
View the diff from clang-format here.
```diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 1fb7c48e3..36f323d2a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -16953,10 +16953,10 @@ SDValue SITargetLowering::performSatAddCombine(SDNode
*N,
if (!Src1)
break;
-auto IterIsSigned = checkDot4MulSignedness(
-TempNode.getOperand(MulIdx), *Src0, *Src1,
-TempNode.getOperand(MulIdx).getOperand(0),
-TempNode.getOperand(MulIdx).getOperand(1), DAG);
+auto IterIsSigned =
+checkDot4MulSignedness(TempNode.getOperand(MulIdx), *Src0, *Src1,
+ TempNode.getOperand(MulIdx).getOperand(0),
+ TempNode.getOperand(MulIdx).getOperand(1), DAG);
if (!IterIsSigned)
break;
if (!MulIsSigned)
@@ -16970,12 +16970,10 @@ SDValue SITargetLowering::performSatAddCombine(SDNode
*N,
// add (mul24, mul24).
if (I == 2 && isMul(TempNode.getOperand(AddIdx))) {
Src2s.push_back(TempNode.getOperand(AddIdx));
- auto Src0 =
- handleMulOperand(TempNode.getOperand(AddIdx).getOperand(0));
+ auto Src0 = handleMulOperand(TempNode.getOperand(AddIdx).getOperand(0));
if (!Src0)
break;
- auto Src1 =
- handleMulOperand(TempNode.getOperand(AddIdx).getOperand(1));
+ auto Src1 = handleMulOperand(TempNode.getOperand(AddIdx).getOperand(1));
if (!Src1)
break;
auto IterIsSigned = checkDot4MulSignedness(
@@ -17040,13 +17038,13 @@ SDValue SITargetLowering::performSatAddCombine(SDNode
*N,
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
auto *SecondElt = Src1s.begin();
- auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
-SecondElt->DWordOffset);
+ auto SecondEltOp =
+ getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
SecondElt->DWordOffset);
- Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
- MVT::getIntegerVT(32));
- Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
- MVT::getIntegerVT(32));
+ Src0 =
+ DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL, MVT::getIntegerVT(32));
+ Src1 =
+ DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
MVT::getIntegerVT(32));
}
}
@@ -17059,12 +17057,12 @@ SDValue SITargetLowering::performSatAddCombine(SDNode
*N,
SDValue Src2 = DAG.getExtOrTrunc(IsSigned, Accum, SL, MVT::i32);
SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
-: Intrinsic::amdgcn_udot4,
+ : Intrinsic::amdgcn_udot4,
SL, MVT::i64);
// Generate dot4 with clamp=1 for saturation
- auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
- Src1, Src2, DAG.getTargetConstant(1, SL, MVT::i1));
+ auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
Src1,
+ Src2, DAG.getTargetConstant(1, SL, MVT::i1));
return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
}
```
https://github.com/llvm/llvm-project/pull/187945
___
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
