llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Pierre van Houtryve (Pierre-vh)

<details>
<summary>Changes</summary>

There are quite a few opcodes that do not care about the exact address space
(AS) of the pointer, only its size.
Adding generic pointer types for these will help reduce duplication in the rule
definitions.

I also switched the usual B types over to the new `isAnyPtr` helper I added, to
make sure they are supersets of the `Ptr` cases.

---
Full diff: https://github.com/llvm/llvm-project/pull/142602.diff


3 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp (+33-9) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp (+25-4) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h (+19) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index 12af7233ffad6..26aa3cf36c87a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -605,17 +605,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB32:
   case UniInVgprB32:
     if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
-        Ty == LLT::pointer(6, 32))
+        isAnyPtr(Ty, 32))
       return Ty;
     return LLT();
+  case SgprPtr32:
+  case VgprPtr32:
+    return isAnyPtr(Ty, 32) ? Ty : LLT();
+  case SgprPtr64:
+  case VgprPtr64:
+    return isAnyPtr(Ty, 64) ? Ty : LLT();
+  case SgprPtr128:
+  case VgprPtr128:
+    return isAnyPtr(Ty, 128) ? Ty : LLT();
   case SgprB64:
   case VgprB64:
   case UniInVgprB64:
     if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
-        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
-        (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
+        Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
       return Ty;
     return LLT();
   case SgprB96:
@@ -629,7 +635,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB128:
   case UniInVgprB128:
     if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
-        Ty == LLT::fixed_vector(2, 64))
+        Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
       return Ty;
     return LLT();
   case SgprB256:
@@ -668,6 +674,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case SgprP5:
   case SgprP6:
   case SgprP8:
+  case SgprPtr32:
+  case SgprPtr64:
+  case SgprPtr128:
   case SgprV2S16:
   case SgprV2S32:
   case SgprV4S32:
@@ -705,6 +714,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case VgprP5:
   case VgprP6:
   case VgprP8:
+  case VgprPtr32:
+  case VgprPtr64:
+  case VgprPtr128:
   case VgprV2S16:
   case VgprV2S32:
   case VgprV4S32:
@@ -778,12 +790,18 @@ void RegBankLegalizeHelper::applyMappingDst(
     case SgprB128:
     case SgprB256:
     case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128:
     case VgprB32:
     case VgprB64:
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
       assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
       break;
@@ -892,7 +910,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case SgprB96:
     case SgprB128:
     case SgprB256:
-    case SgprB512: {
+    case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       assert(RB == getRegBankFromID(MethodIDs[i]));
       break;
@@ -926,7 +947,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       if (RB != VgprRB) {
         auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 08a35b9794344..b6260076731ba 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -26,6 +26,10 @@
 using namespace llvm;
 using namespace AMDGPU;
 
+bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
+  return Ty.isPointer() && Ty.getSizeInBits() == Width;
+}
+
 RegBankLLTMapping::RegBankLLTMapping(
     std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
     std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -68,6 +72,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32);
   case P8:
     return MRI.getType(Reg) == LLT::pointer(8, 128);
+  case Ptr32:
+    return isAnyPtr(MRI.getType(Reg), 32);
+  case Ptr64:
+    return isAnyPtr(MRI.getType(Reg), 64);
+  case Ptr128:
+    return isAnyPtr(MRI.getType(Reg), 128);
   case V2S32:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
   case V4S32:
@@ -110,6 +120,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isUniform(Reg);
   case UniP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isUniform(Reg);
+  case UniPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
+  case UniPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
+  case UniPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
   case UniV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
   case UniB32:
@@ -150,6 +166,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isDivergent(Reg);
   case DivP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isDivergent(Reg);
+  case DivPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
+  case DivPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
+  case DivPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
   case DivV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
   case DivB32:
@@ -223,15 +245,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
 
 UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
   if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 32))
+      isAnyPtr(Ty, 32))
     return B32;
   if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-      Ty == LLT::fixed_vector(4, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 64))
+      Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
     return B64;
   if (Ty == LLT::fixed_vector(3, 32))
     return B96;
-  if (Ty == LLT::fixed_vector(4, 32))
+  if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
     return B128;
   return _;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 14be873b6ce19..1d429f711fbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -15,6 +15,7 @@
 
 namespace llvm {
 
+class LLT;
 class MachineRegisterInfo;
 class MachineInstr;
 class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 
 namespace AMDGPU {
 
+/// \returns true if \p Ty is a pointer type with size \p Width.
+bool isAnyPtr(LLT Ty, unsigned Width);
+
 // IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
 // or more IDs and each represents a check for 'uniform or divergent' + LLT or
 // just LLT on register operand.
@@ -62,6 +66,9 @@ enum UniformityLLTOpPredicateID {
   P5,
   P6,
   P8,
+  Ptr32,
+  Ptr64,
+  Ptr128,
 
   UniP0,
   UniP1,
@@ -71,6 +78,9 @@ enum UniformityLLTOpPredicateID {
   UniP5,
   UniP6,
   UniP8,
+  UniPtr32,
+  UniPtr64,
+  UniPtr128,
 
   DivP0,
   DivP1,
@@ -80,6 +90,9 @@ enum UniformityLLTOpPredicateID {
   DivP5,
   DivP6,
   DivP8,
+  DivPtr32,
+  DivPtr64,
+  DivPtr128,
 
   // vectors
   V2S16,
@@ -138,6 +151,9 @@ enum RegBankLLTMappingApplyID {
   SgprP5,
   SgprP6,
   SgprP8,
+  SgprPtr32,
+  SgprPtr64,
+  SgprPtr128,
   SgprV2S16,
   SgprV4S32,
   SgprV2S32,
@@ -161,6 +177,9 @@ enum RegBankLLTMappingApplyID {
   VgprP5,
   VgprP6,
   VgprP8,
+  VgprPtr32,
+  VgprPtr64,
+  VgprPtr128,
   VgprV2S16,
   VgprV2S32,
   VgprB32,

``````````

</details>


https://github.com/llvm/llvm-project/pull/142602
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to