llvmbot wrote:
@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes: Recognize a reg_sequence whose 32-bit elements produce a 64-bit splat value. This enables folding f64 constants into mfma operands.

---
Full diff: https://github.com/llvm/llvm-project/pull/140878.diff

2 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/SIFoldOperands.cpp (+70-33)
- (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll (+6-35)

``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
index eb7fb94e25f5c..70e3974bb22b4 100644
--- a/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
+++ b/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -967,15 +967,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
@@ -987,38 +987,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
@@ -1032,17 +1069,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1120,7 +1158,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1131,10 +1169,9 @@ void SIFoldOperandsImpl::foldOperand(
       MachineInstr *RSUseMI = RSUse->getParent();
       unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-      if (SplatVal) {
-        if (MachineOperand *Foldable =
-                tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-          FoldableDef SplatDef(*Foldable, SplatRC);
+      if (SplatRC) {
+        if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+          FoldableDef SplatDef(SplatVal, SplatRC);
           appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
           continue;
         }
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
index 5d5dc01439fe4..a9cffd6e1c943 100644
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll
@@ -165,19 +165,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
-; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
 }
 
 ; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
-; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
-; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
-; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
-
-; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
-; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
+; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
 ; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
 ; GCN: global_store_dwordx4
 ; GCN: global_store_dwordx4
``````````

https://github.com/llvm/llvm-project/pull/140878
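As a note for readers skimming the diff: the new fallback in `isRegSeqSplat` walks the 32-bit reg_sequence inputs in (lo, hi) pairs, merges each pair into a 64-bit value, and only reports a splat if every pair merges to the same constant. Below is a minimal standalone sketch of that merge-and-compare step; `match64BitSplat` and `make64` are hypothetical stand-ins for the patch's loop over `Defs` and LLVM's `Make_64`, and the subregister-ordering check from the patch is omitted.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for llvm::Make_64: high word in the upper 32 bits, low word in
// the lower 32 bits.
static int64_t make64(uint32_t Hi, uint32_t Lo) {
  return static_cast<int64_t>((static_cast<uint64_t>(Hi) << 32) | Lo);
}

// Return the 64-bit splat value if every (lo, hi) pair of 32-bit pieces
// merges to the same 64-bit constant, mirroring the fallback loop in
// isRegSeqSplat (minus the subreg ordering check).
static std::optional<int64_t>
match64BitSplat(const std::vector<uint32_t> &Pieces) {
  if (Pieces.empty() || (Pieces.size() & 1) != 0)
    return std::nullopt; // Need a non-empty, even number of 32-bit elements.

  int64_t Splat = make64(Pieces[1], Pieces[0]);
  for (size_t I = 2; I != Pieces.size(); I += 2)
    if (make64(Pieces[I + 1], Pieces[I]) != Splat)
      return std::nullopt;
  return Splat;
}

int main() {
  // f64 1.0 is 0x3FF0000000000000: each 64-bit element arrives as a low
  // word of 0 and a high word of 0x3FF00000. Eight 32-bit pieces model a
  // 256-bit MFMA accumulator.
  std::vector<uint32_t> OnesF64 = {0, 0x3FF00000, 0, 0x3FF00000,
                                   0, 0x3FF00000, 0, 0x3FF00000};
  return match64BitSplat(OnesF64) == int64_t(0x3FF0000000000000) ? 0 : 1;
}
```

This is why the updated tests can check for an inline `1.0`, `-1.0`, or `64` MFMA operand instead of the chain of `v_accvgpr_write_b32`/`v_accvgpr_mov_b32` instructions that previously materialized the splat.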