This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d32dea800b [METAL] Update metal runtime to directly store kernel map
(#14727)
d32dea800b is described below
commit d32dea800baeda14d144d0524a5da435c9cb160b
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 26 19:04:19 2023 -0400
[METAL] Update metal runtime to directly store kernel map (#14727)
This PR updates the metal runtime storage format to
directly store the kernel map.
This enables more robust support for the
metallib binary format, which may not be
compatible with the previous string-splitting approach.
This change alters the binary format of the metal module.
We also added a version string to make future format updates easier.
---
src/runtime/metal/metal_module.h | 11 ++--
src/runtime/metal/metal_module.mm | 108 ++++++++++++++++++-------------------
src/target/opt/build_metal_off.cc | 7 +--
src/target/source/codegen_metal.cc | 31 ++++++-----
4 files changed, 79 insertions(+), 78 deletions(-)
diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h
index 77cdf64df8..d01523b1fa 100644
--- a/src/runtime/metal/metal_module.h
+++ b/src/runtime/metal/metal_module.h
@@ -41,13 +41,14 @@ static constexpr const int kMetalMaxNumDevice = 32;
/*!
* \brief create a metal module from data.
*
- * \param data The data content.
- * \param fmt The format of the data, can be "metal" or "metallib"
+ * \param smap The map from name to each shader kernel.
* \param fmap The map function information map of each function.
- * \param source Optional, source file
+ * \param fmt The format of the source, can be "metal" or "metallib"
+ * \param source Optional, source file, concatenated for debug dump
*/
-Module MetalModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
std::string source);
+Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap,
std::string fmt,
+ std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index a5eddf3a95..aef6cf5ebe 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -30,23 +30,26 @@
#include "../file_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
-#include "../source_utils.h"
#include "../thread_storage_scope.h"
#include "metal_common.h"
namespace tvm {
namespace runtime {
+// The version of metal module
+// for future compatibility checking
+// bump when we change the binary format.
+static constexpr const char* kMetalModuleVersion = "0.1.0";
+
// Module to support thread-safe multi-GPU execution.
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class MetalModuleNode final : public runtime::ModuleNode {
public:
- explicit MetalModuleNode(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
std::string source)
- : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
- parsed_kernels_ = SplitKernels(data);
- }
+ explicit MetalModuleNode(std::unordered_map<std::string, std::string> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap,
std::string fmt,
+ std::string source)
+ : smap_(smap), fmap_(fmap), fmt_(fmt), source_(source) {}
const char* type_key() const final { return "metal"; }
/*! \brief Get the property of the runtime module. */
@@ -57,27 +60,19 @@ class MetalModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) final;
void SaveToFile(const std::string& file_name, const std::string& format)
final {
- std::string fmt = GetFileFormat(file_name, format);
- ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
- std::string meta_file = GetMetaFilePath(file_name);
- SaveMetaDataToFile(meta_file, fmap_);
- SaveBinaryToFile(file_name, data_);
+ LOG(FATAL) << "Do not support save to file, use save to binary and export
instead";
}
void SaveToBinary(dmlc::Stream* stream) final {
- stream->Write(fmt_);
+ std::string version = kMetalModuleVersion;
+ stream->Write(version);
+ stream->Write(smap_);
stream->Write(fmap_);
- stream->Write(data_);
+ stream->Write(fmt_);
}
std::string GetSource(const std::string& format) final {
- if (format == fmt_) return data_;
- if (source_.length() != 0) {
- return source_;
- } else if (fmt_ == "metal") {
- return data_;
- } else {
- return "";
- }
+ // return text source if available.
+ return source_;
}
// get a from primary context in device_id
@@ -95,15 +90,11 @@ class MetalModuleNode final : public runtime::ModuleNode {
// compile
NSError* err_msg = nil;
id<MTLLibrary> lib = nil;
- std::string source;
- auto kernel = parsed_kernels_.find(func_name);
- // If we cannot find this kernel in parsed_kernels_, it means that all
kernels going together
- // without explicit separator. In this case we use data_ with all kernels.
It done for backward
- // compatibility.
- if (kernel != parsed_kernels_.end())
- source = kernel->second;
- else
- source = data_;
+ auto kernel = smap_.find(func_name);
+ // Directly lookup kernels
+ ICHECK(kernel != smap_.end());
+ const std::string& source = kernel->second;
+
if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
@@ -115,7 +106,8 @@ class MetalModuleNode final : public runtime::ModuleNode {
error:&err_msg];
[opts dealloc];
if (lib == nil) {
- LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg
localizedDescription] UTF8String];
+ LOG(FATAL) << "Fail to compile metal source:"
+ << [[err_msg localizedDescription] UTF8String];
}
if (err_msg != nil) {
LOG(INFO) << "Warning: " << [[err_msg localizedDescription]
UTF8String];
@@ -161,20 +153,18 @@ class MetalModuleNode final : public runtime::ModuleNode {
}
}
};
- // the binary data
- std::string data_;
- // The format
- std::string fmt_;
+ // the source shader data, can be mtl or binary
+ std::unordered_map<std::string, std::string> smap_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
+ // The format
+ std::string fmt_;
// The source
std::string source_;
// function information.
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
- // parsed kernel data
- std::unordered_map<std::string, std::string> parsed_kernels_;
};
// a wrapped function class to get packed func.
@@ -272,39 +262,45 @@ PackedFunc MetalModuleNode::GetFunction(const
std::string& name,
return pf;
}
-Module MetalModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
+Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap,
std::string fmt,
+ std::string source) {
ObjectPtr<Object> n;
AUTORELEASEPOOL {
metal::MetalWorkspace::Global()->Init();
- n = make_object<MetalModuleNode>(data, fmt, fmap, source);
+ n = make_object<MetalModuleNode>(smap, fmap, fmt, source);
};
return Module(n);
}
-// Load module from module.
-Module MetalModuleLoadFile(const std::string& file_name, const std::string&
format) {
- std::string data;
- std::unordered_map<std::string, FunctionInfo> fmap;
- std::string fmt = GetFileFormat(file_name, format);
- std::string meta_file = GetMetaFilePath(file_name);
- LoadBinaryFromFile(file_name, &data);
- LoadMetaDataFromFile(meta_file, &fmap);
- return MetalModuleCreate(data, fmt, fmap, "");
-}
+TVM_REGISTER_GLOBAL("runtime.module.create_metal_module")
+ .set_body_typed([](Map<String, String> smap, std::string fmap_json,
std::string fmt,
+ std::string source) {
+ std::istringstream stream(fmap_json);
+ std::unordered_map<std::string, FunctionInfo> fmap;
+ dmlc::JSONReader reader(&stream);
+ reader.Read(&fmap);
+ return MetalModuleCreate(
+ std::unordered_map<std::string, std::string>(smap.begin(),
smap.end()), fmap, fmt,
+ source);
+ });
Module MetalModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
- std::string data;
+ // version is reserved for future changes and
+ // is discarded for now
+ std::string ver;
+ std::unordered_map<std::string, std::string> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
- stream->Read(&fmt);
+
+ stream->Read(&ver);
+ stream->Read(&smap);
stream->Read(&fmap);
- stream->Read(&data);
- return MetalModuleCreate(data, fmt, fmap, "");
-}
+ stream->Read(&fmt);
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile);
+ return MetalModuleCreate(smap, fmap, fmt, "");
+}
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
diff --git a/src/target/opt/build_metal_off.cc
b/src/target/opt/build_metal_off.cc
index 3cfe1316e7..555aa5002f 100644
--- a/src/target/opt/build_metal_off.cc
+++ b/src/target/opt/build_metal_off.cc
@@ -26,10 +26,11 @@
namespace tvm {
namespace runtime {
-Module MetalModuleCreate(std::string data, std::string fmt,
- std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
+Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap,
std::string fmt,
+ std::string source) {
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
- return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal");
+ return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal");
}
} // namespace runtime
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index 767311cb5a..44da240dd5 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -24,6 +24,7 @@
#include <algorithm>
#include <string>
+#include <unordered_map>
#include <vector>
#include "../../runtime/metal/metal_module.h"
@@ -336,33 +337,35 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
- std::stringstream code;
- std::stringstream source;
- std::string fmt = "metal";
+ std::ostringstream source_maker;
+ std::unordered_map<std::string, std::string> smap;
+ const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
+ std::string fmt = fmetal_compile ? "metallib" : "metal";
+
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only
take PrimFunc";
- code << "// Function: " << kv.first->name_hint << std::endl;
+ auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined());
+ std::string func_name = global_symbol.value();
+
+ source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+
cg.AddFunction(f);
std::string fsource = cg.Finish();
- if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
- source << fsource;
- fsource = (*f)(fsource).operator std::string();
- fmt = "metallib";
+ source_maker << fsource << "\n";
+ if (fmetal_compile) {
+ fsource = (*fmetal_compile)(fsource).operator std::string();
}
- code << fsource;
+ smap[func_name] = fsource;
}
- std::string code_str = code.str();
- if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) {
- code_str = (*f)(code_str).operator std::string();
- }
- return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str());
+ return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt,
source_maker.str());
}
TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);