This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bee073b0c8 [LLVM] Minor refactor to LLVMModuleNode::SaveToFile (#15139)
bee073b0c8 is described below
commit bee073b0c8e8625216184a2dbb0204c0a376fc26
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 22 07:22:38 2023 -0500
[LLVM] Minor refactor to LLVMModuleNode::SaveToFile (#15139)
Previously, the `#if TVM_LLVM_VERSION` checks made it difficult to
follow the logic of `LLVMModuleNode::SaveToFile` while debugging.
This commit moves the preprocessor directives into wrapper
functions and constants that preserve the same LLVM-version
compatibility, making the control flow of `SaveToFile` easier to follow.
---
src/target/llvm/llvm_module.cc | 78 ++++++++++++++++++++++--------------------
1 file changed, 40 insertions(+), 38 deletions(-)
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 4bb36ad284..85750fbf14 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -180,57 +180,59 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr<Objec
return WrapPackedFunc(faddr, sptr_to_self);
}
-void LLVMModuleNode::SaveToFile(const String& file_name_str, const String&
format) {
- std::string file_name = file_name_str;
- std::string fmt = runtime::GetFileFormat(file_name, format);
- std::error_code ecode;
+namespace {
#if TVM_LLVM_VERSION <= 70
- llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
+constexpr auto llvm_open_output_flag = llvm::sys::fs::F_None;
#else
- llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None);
+constexpr auto llvm_open_output_flag = llvm::sys::fs::OF_None;
#endif
- ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " <<
ecode.message();
- if (fmt == "o" || fmt == "obj") {
- With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
+
#if TVM_LLVM_VERSION <= 60
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(module_);
+std::unique_ptr<llvm::Module> CloneLLVMModule(llvm::Module* mod) { return
llvm::CloneModule(mod); }
#else
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(*module_);
+std::unique_ptr<llvm::Module> CloneLLVMModule(llvm::Module* mod) { return
llvm::CloneModule(*mod); }
#endif
- llvm::legacy::PassManager pass;
- llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
-#if TVM_LLVM_VERSION <= 60
- ICHECK(tm->addPassesToEmitFile(pass, dest,
llvm::TargetMachine::CGFT_ObjectFile) == 0)
- << "Cannot emit target CGFT_ObjectFile";
-#elif TVM_LLVM_VERSION <= 90
- ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr,
llvm::TargetMachine::CGFT_ObjectFile) == 0)
- << "Cannot emit target CGFT_ObjectFile";
+
+#if TVM_LLVM_VERSION <= 90
+constexpr auto llvm_object_file_target = llvm::TargetMachine::CGFT_ObjectFile;
+constexpr auto llvm_assembly_file_target =
llvm::TargetMachine::CGFT_AssemblyFile;
#else
- ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile)
== 0)
- << "Cannot emit target CGFT_ObjectFile";
+constexpr auto llvm_object_file_target = llvm::CGFT_ObjectFile;
+constexpr auto llvm_assembly_file_target = llvm::CGFT_AssemblyFile;
#endif
- pass.run(*m);
- } else if (fmt == "s" || fmt == "asm") {
- With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
+
+bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm,
llvm::legacy::PassManager* pm,
+ llvm::raw_fd_ostream* dest,
+ decltype(llvm_object_file_target)
llvm_file_target) {
#if TVM_LLVM_VERSION <= 60
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(module_);
+ return tm->addPassesToEmitFile(*pm, *dest, llvm_file_target);
#else
- std::unique_ptr<llvm::Module> m = llvm::CloneModule(*module_);
+ return tm->addPassesToEmitFile(*pm, *dest, nullptr, llvm_file_target);
#endif
+}
+
+} // namespace
+
+void LLVMModuleNode::SaveToFile(const String& file_name_str, const String&
format) {
+ // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules";
+ std::string file_name = file_name_str;
+ std::string fmt = runtime::GetFileFormat(file_name, format);
+ std::error_code ecode;
+ llvm::raw_fd_ostream dest(file_name, ecode, llvm_open_output_flag);
+ ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " <<
ecode.message();
+ bool is_obj_file = fmt == "o" || fmt == "obj";
+ bool is_asm_file = fmt == "s" || fmt == "asm";
+ if (is_obj_file || is_asm_file) {
+ auto llvm_file_target = is_obj_file ? llvm_object_file_target :
llvm_assembly_file_target;
+
+ With<LLVMTarget> llvm_target(*llvm_instance_,
LLVMTarget::GetTargetMetadata(*module_));
llvm::legacy::PassManager pass;
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
-#if TVM_LLVM_VERSION <= 60
- ICHECK(tm->addPassesToEmitFile(pass, dest,
llvm::TargetMachine::CGFT_AssemblyFile) == 0)
- << "Cannot emit target CGFT_AssemblyFile";
-#elif TVM_LLVM_VERSION <= 90
- ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr,
llvm::TargetMachine::CGFT_AssemblyFile) ==
- 0)
- << "Cannot emit target CGFT_AssemblyFile";
-#else
- ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr,
llvm::CGFT_AssemblyFile) == 0)
- << "Cannot emit target CGFT_AssemblyFile";
-#endif
- pass.run(*m);
+
+ auto err = LLVMAddPassesToEmitFile(tm, &pass, &dest, llvm_file_target);
+ ICHECK(!err) << "Cannot emit target CGFT_ObjectFile";
+
+ pass.run(*CloneLLVMModule(module_));
} else if (fmt == "ll") {
module_->print(dest, nullptr);
} else if (fmt == "bc") {