https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/140878

Recognize a reg_sequence of 32-bit elements that produces a 64-bit
splat value. This enables folding f64 constants into mfma operands.
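
For illustration only (not part of the patch): a minimal standalone sketch of
the pair-merging check being added, using plain integers in place of the
MachineOperand list the pass actually walks and omitting the subregister-order
check. The name detectSplat64 and the use of std::optional are assumptions of
this sketch, not code from SIFoldOperands.cpp.

  #include <cstdint>
  #include <optional>
  #include <vector>

  // Given the 32-bit immediates of a reg_sequence in subregister order,
  // return the 64-bit value they splat, if every (lo, hi) pair merges to
  // the same 64-bit constant.
  static std::optional<uint64_t> detectSplat64(const std::vector<uint32_t> &Elts) {
    if (Elts.size() < 2 || (Elts.size() & 1) != 0)
      return std::nullopt; // Need a non-empty, even number of 32-bit pieces.
    uint64_t Splat = 0;
    for (size_t I = 0, E = Elts.size(); I != E; I += 2) {
      // The even element is the low half, the odd element the high half.
      uint64_t Merged = (uint64_t(Elts[I + 1]) << 32) | Elts[I];
      if (I == 0)
        Splat = Merged;
      else if (Merged != Splat)
        return std::nullopt; // Not a uniform 64-bit splat.
    }
    return Splat;
  }

For example, the elements {0x00000000, 0x3ff00000, 0x00000000, 0x3ff00000}
merge to 0x3ff0000000000000, the bit pattern of double 1.0, which is a legal
inline immediate for the f64 mfma source operand in the test changes below.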

From b68d880b7872cd90d3aa79419800cfc505305b76 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <matthew.arsena...@amd.com>
Date: Mon, 19 May 2025 21:51:06 +0200
Subject: [PATCH] AMDGPU: Handle folding vector splats of split f64 inline immediates

Recognize a reg_sequence of 32-bit elements that produces a 64-bit
splat value. This enables folding f64 constants into mfma operands.
---
 llvm/lib/Target/AMDGPU/SIFoldOperands.cpp     | 103 ++++++++++++------
 .../CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll |  41 +------
 2 files changed, 76 insertions(+), 68 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
index eb7fb94e25f5c..70e3974bb22b4 100644
--- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fall back to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only values whose 32-bit and 64-bit
+  // splat interpretations are the same.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
    // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
index 5d5dc01439fe4..a9cffd6e1c943 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
@@ -165,19 +165,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN:    global_store_dwordx4
 ; GCN:    global_store_dwordx4
