llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-amdgpu Author: Pierre van Houtryve (Pierre-vh) <details> <summary>Changes</summary> There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, only its size. Adding generic types for these will help reduce duplication in the rule definitions. I also switched the usual B types to use the new `isAnyPtr` helper I added, to make sure they're supersets of the `Ptr` cases. --- Full diff: https://github.com/llvm/llvm-project/pull/142602.diff 3 Files Affected: - (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp (+33-9) - (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp (+25-4) - (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h (+19) ``````````diff diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp index 12af7233ffad6..26aa3cf36c87a 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp @@ -605,17 +605,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) { case VgprB32: case UniInVgprB32: if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) || - Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) || - Ty == LLT::pointer(6, 32)) + isAnyPtr(Ty, 32)) return Ty; return LLT(); + case SgprPtr32: + case VgprPtr32: + return isAnyPtr(Ty, 32) ? Ty : LLT(); + case SgprPtr64: + case VgprPtr64: + return isAnyPtr(Ty, 64) ? Ty : LLT(); + case SgprPtr128: + case VgprPtr128: + return isAnyPtr(Ty, 128) ?
Ty : LLT(); case SgprB64: case VgprB64: case UniInVgprB64: if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) || - Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) || - Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) || - (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS)) + Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64)) return Ty; return LLT(); case SgprB96: @@ -629,7 +635,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) { case VgprB128: case UniInVgprB128: if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) || - Ty == LLT::fixed_vector(2, 64)) + Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128)) return Ty; return LLT(); case SgprB256: @@ -668,6 +674,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { case SgprP5: case SgprP6: case SgprP8: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: case SgprV2S16: case SgprV2S32: case SgprV4S32: @@ -705,6 +714,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { case VgprP5: case VgprP6: case VgprP8: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: case VgprV2S16: case VgprV2S32: case VgprV4S32: @@ -778,12 +790,18 @@ void RegBankLegalizeHelper::applyMappingDst( case SgprB128: case SgprB256: case SgprB512: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: case VgprB32: case VgprB64: case VgprB96: case VgprB128: case VgprB256: - case VgprB512: { + case VgprB512: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty)); assert(RB == getRegBankFromID(MethodIDs[OpIdx])); break; @@ -892,7 +910,10 @@ void RegBankLegalizeHelper::applyMappingSrc( case SgprB96: case SgprB128: case SgprB256: - case SgprB512: { + case SgprB512: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[i], Ty)); assert(RB == getRegBankFromID(MethodIDs[i])); break; @@ -926,7 +947,10 @@ void 
RegBankLegalizeHelper::applyMappingSrc( case VgprB96: case VgprB128: case VgprB256: - case VgprB512: { + case VgprB512: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[i], Ty)); if (RB != VgprRB) { auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index 08a35b9794344..b6260076731ba 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -26,6 +26,10 @@ using namespace llvm; using namespace AMDGPU; +bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) { + return Ty.isPointer() && Ty.getSizeInBits() == Width; +} + RegBankLLTMapping::RegBankLLTMapping( std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList, std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList, @@ -68,6 +72,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(6, 32); case P8: return MRI.getType(Reg) == LLT::pointer(8, 128); + case Ptr32: + return isAnyPtr(MRI.getType(Reg), 32); + case Ptr64: + return isAnyPtr(MRI.getType(Reg), 64); + case Ptr128: + return isAnyPtr(MRI.getType(Reg), 128); case V2S32: return MRI.getType(Reg) == LLT::fixed_vector(2, 32); case V4S32: @@ -110,6 +120,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isUniform(Reg); case UniP8: return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isUniform(Reg); + case UniPtr32: + return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg); + case UniPtr64: + return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg); + case UniPtr128: + return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg); case UniV2S16: return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg); case UniB32: @@ -150,6 +166,12 @@ bool 
matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isDivergent(Reg); case DivP8: return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isDivergent(Reg); + case DivPtr32: + return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg); + case DivPtr64: + return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg); + case DivPtr128: + return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg); case DivV2S16: return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg); case DivB32: @@ -223,15 +245,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) { UniformityLLTOpPredicateID LLTToBId(LLT Ty) { if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) || - (Ty.isPointer() && Ty.getSizeInBits() == 32)) + isAnyPtr(Ty, 32)) return B32; if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) || - Ty == LLT::fixed_vector(4, 16) || - (Ty.isPointer() && Ty.getSizeInBits() == 64)) + Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64)) return B64; if (Ty == LLT::fixed_vector(3, 32)) return B96; - if (Ty == LLT::fixed_vector(4, 32)) + if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128)) return B128; return _; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h index 14be873b6ce19..1d429f711fbf6 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h @@ -15,6 +15,7 @@ namespace llvm { +class LLT; class MachineRegisterInfo; class MachineInstr; class GCNSubtarget; @@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>; namespace AMDGPU { +/// \returns true if \p Ty is a pointer type with size \p Width. +bool isAnyPtr(LLT Ty, unsigned Width); + // IDs used to build predicate for RegBankLegalizeRule. 
Predicate can have one // or more IDs and each represents a check for 'uniform or divergent' + LLT or // just LLT on register operand. @@ -62,6 +66,9 @@ enum UniformityLLTOpPredicateID { P5, P6, P8, + Ptr32, + Ptr64, + Ptr128, UniP0, UniP1, @@ -71,6 +78,9 @@ enum UniformityLLTOpPredicateID { UniP5, UniP6, UniP8, + UniPtr32, + UniPtr64, + UniPtr128, DivP0, DivP1, @@ -80,6 +90,9 @@ enum UniformityLLTOpPredicateID { DivP5, DivP6, DivP8, + DivPtr32, + DivPtr64, + DivPtr128, // vectors V2S16, @@ -138,6 +151,9 @@ enum RegBankLLTMappingApplyID { SgprP5, SgprP6, SgprP8, + SgprPtr32, + SgprPtr64, + SgprPtr128, SgprV2S16, SgprV4S32, SgprV2S32, @@ -161,6 +177,9 @@ enum RegBankLLTMappingApplyID { VgprP5, VgprP6, VgprP8, + VgprPtr32, + VgprPtr64, + VgprPtr128, VgprV2S16, VgprV2S32, VgprB32, `````````` </details> https://github.com/llvm/llvm-project/pull/142602 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits