llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: Justin Bogner (bogner) <details> <summary>Changes</summary> This introduces a class "OpLowerer", declared in an anonymous namespace, to help with lowering DXIL ops, and moves the DXILOpBuilder there instead of creating a new one for every operation. DXILOpBuilder is also changed to own its IRBuilder, since that makes it simpler to ensure that it isn't misused. --- Full diff: https://github.com/llvm/llvm-project/pull/104248.diff 3 Files Affected: - (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+3-4) - (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+6-3) - (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+64-45) ``````````diff diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 987437619f08e..7d2b40cc515cc 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -11,7 +11,6 @@ #include "DXILOpBuilder.h" #include "DXILConstants.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" @@ -335,7 +334,7 @@ namespace dxil { // Triple is well-formed or that the target is supported since these checks // would have been done at the time the module M is constructed in the earlier // stages of compilation. -DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) { +DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) { Triple TT(Triple(M.getTargetTriple())); DXILVersion = TT.getDXILVersion(); ShaderStage = TT.getEnvironment(); @@ -417,10 +416,10 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, // We need to inject the opcode as the first argument.
SmallVector<Value *> OpArgs; - OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode))); + OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode))); OpArgs.append(Args.begin(), Args.end()); - return B.CreateCall(DXILFn, OpArgs); + return IRB.CreateCall(DXILFn, OpArgs); } CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args, diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index 5d83357f7a2e9..483d5ddc8b619 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -14,8 +14,9 @@ #include "DXILConstants.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/TargetParser/Triple.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/Support/Error.h" +#include "llvm/TargetParser/Triple.h" namespace llvm { class Module; @@ -29,7 +30,9 @@ namespace dxil { class DXILOpBuilder { public: - DXILOpBuilder(Module &M, IRBuilderBase &B); + DXILOpBuilder(Module &M); + + IRBuilder<> &getIRB() { return IRB; } /// Create a call instruction for the given DXIL op. The arguments /// must be valid for an overload of the operation. 
@@ -51,7 +54,7 @@ class DXILOpBuilder { Type *OverloadType = nullptr); Module &M; - IRBuilderBase &B; + IRBuilder<> IRB; VersionTuple DXILVersion; Triple::EnvironmentType ShaderStage; }; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 5f84cdcfda6de..e458720fcd6e9 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -73,67 +73,84 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig, return NewOperands; } -static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { - IRBuilder<> B(M.getContext()); - DXILOpBuilder OpBuilder(M, B); - for (User *U : make_early_inc_range(F.users())) { - CallInst *CI = dyn_cast<CallInst>(U); - if (!CI) - continue; - - SmallVector<Value *> Args; - B.SetInsertPoint(CI); - if (isVectorArgExpansion(F)) { - SmallVector<Value *> NewArgs = argVectorFlatten(CI, B); - Args.append(NewArgs.begin(), NewArgs.end()); - } else - Args.append(CI->arg_begin(), CI->arg_end()); - - Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args, - F.getReturnType()); - if (Error E = OpCallOrErr.takeError()) { - std::string Message(toString(std::move(E))); - DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, - CI->getDebugLoc()); - M.getContext().diagnose(Diag); - continue; +namespace { +class OpLowerer { + Module &M; + DXILOpBuilder OpBuilder; + +public: + OpLowerer(Module &M) : M(M), OpBuilder(M) {} + + void replaceFunction(Function &F, + llvm::function_ref<Error(CallInst *CI)> ReplaceCall) { + for (User *U : make_early_inc_range(F.users())) { + CallInst *CI = dyn_cast<CallInst>(U); + if (!CI) + continue; + + if (Error E = ReplaceCall(CI)) { + std::string Message(toString(std::move(E))); + DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, + CI->getDebugLoc()); + M.getContext().diagnose(Diag); + continue; + } } - CallInst *OpCall = *OpCallOrErr; + if (F.user_empty()) + F.eraseFromParent(); + } - 
CI->replaceAllUsesWith(OpCall); - CI->eraseFromParent(); + void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) { + bool IsVectorArgExpansion = isVectorArgExpansion(F); + replaceFunction(F, [&](CallInst *CI) -> Error { + SmallVector<Value *> Args; + OpBuilder.getIRB().SetInsertPoint(CI); + if (IsVectorArgExpansion) { + SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB()); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); + + Expected<CallInst *> OpCall = + OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType()); + if (Error E = OpCall.takeError()) + return E; + + CI->replaceAllUsesWith(*OpCall); + CI->eraseFromParent(); + return Error::success(); + }); } - if (F.user_empty()) - F.eraseFromParent(); -} -static bool lowerIntrinsics(Module &M) { - bool Updated = false; + bool lowerIntrinsics() { + bool Updated = false; - for (Function &F : make_early_inc_range(M.functions())) { - if (!F.isDeclaration()) - continue; - Intrinsic::ID ID = F.getIntrinsicID(); - switch (ID) { - default: - continue; + for (Function &F : make_early_inc_range(M.functions())) { + if (!F.isDeclaration()) + continue; + Intrinsic::ID ID = F.getIntrinsicID(); + switch (ID) { + default: + continue; #define DXIL_OP_INTRINSIC(OpCode, Intrin) \ case Intrin: \ - lowerIntrinsic(OpCode, F, M); \ + replaceFunctionWithOp(F, OpCode); \ break; #include "DXILOperation.inc" + } + Updated = true; } - Updated = true; + return Updated; } - return Updated; -} +}; +} // namespace namespace { /// A pass that transforms external global definitions into declarations. 
class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { public: PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { - if (lowerIntrinsics(M)) + if (OpLowerer(M).lowerIntrinsics()) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } @@ -143,7 +160,9 @@ class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { namespace { class DXILOpLoweringLegacy : public ModulePass { public: - bool runOnModule(Module &M) override { return lowerIntrinsics(M); } + bool runOnModule(Module &M) override { + return OpLowerer(M).lowerIntrinsics(); + } StringRef getPassName() const override { return "DXIL Op Lowering"; } DXILOpLoweringLegacy() : ModulePass(ID) {} `````````` </details> https://github.com/llvm/llvm-project/pull/104248 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits