================ @@ -98,21 +110,98 @@ void applySPIRVDistance(MachineInstr &MI, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR = MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry(); - auto RemoveAllUses = [&](Register Reg) { - SmallVector<MachineInstr *, 4> UsesToErase( - llvm::make_pointer_range(MRI.use_instructions(Reg))); - - // calling eraseFromParent to early invalidates the iterator. - for (auto *MIToErase : UsesToErase) { - GR->invalidateMachineInstr(MIToErase); - MIToErase->eraseFromParent(); - } - }; - RemoveAllUses(SubDestReg); // remove all uses of FSUB Result + removeAllUses(SubDestReg, MRI, GR); // remove all uses of FSUB Result GR->invalidateMachineInstr(SubInstr); SubInstr->eraseFromParent(); // remove FSUB instruction } +/// This match is part of a combine that +/// rewrites select(fcmp(dot(I, Ng), 0), N, 0 - N) to faceforward(N, I, Ng) +/// (vXf32 (g_select +/// (g_fcmp +/// (g_intrinsic dot(vXf32 I) (vXf32 Ng) +/// 0) +/// (vXf32 N) +/// (vXf32 g_fsub (0) (vXf32 N)))) +/// -> +/// (vXf32 (g_intrinsic faceforward +/// (vXf32 N) (vXf32 I) (vXf32 Ng))) +/// +bool matchSelectToFaceForward(MachineInstr &MI, MachineRegisterInfo &MRI) { + if (MI.getOpcode() != TargetOpcode::G_SELECT) + return false; + + // Check if select's condition is a comparison between a dot product and 0. 
+ Register CondReg = MI.getOperand(1).getReg(); + MachineInstr *CondInstr = MRI.getVRegDef(CondReg); + if (!CondInstr || CondInstr->getOpcode() != TargetOpcode::G_FCMP) + return false; + + Register DotReg = CondInstr->getOperand(2).getReg(); + MachineInstr *DotInstr = MRI.getVRegDef(DotReg); + if (DotInstr->getOpcode() != TargetOpcode::G_FMUL && + (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC || + cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot)) + return false; + + Register CondZeroReg = CondInstr->getOperand(3).getReg(); + MachineInstr *CondZeroInstr = MRI.getVRegDef(CondZeroReg); + if (CondZeroInstr->getOpcode() != TargetOpcode::G_FCONSTANT || + !CondZeroInstr->getOperand(1).getFPImm()->isZero()) + return false; + + // Check if select's false operand is the negation of the true operand. + Register TrueReg = MI.getOperand(2).getReg(); + Register FalseReg = MI.getOperand(3).getReg(); + MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg); + if (FalseInstr->getOpcode() != TargetOpcode::G_FNEG) + return false; + if (TrueReg != FalseInstr->getOperand(1).getReg()) + return false; ---------------- kmpeng wrote:
Yeah, I was also wondering whether this way of checking would be a problem, since there are multiple ways to represent two operands being negations of each other (another example being `G_FSUB` from zero, which is the form the header comment above describes). I'm also still learning and have not been able to come up with a generic way to check this. @farzonl Do you have any thoughts? https://github.com/llvm/llvm-project/pull/139959 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits