This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 006b11df04 [RUNTIME] Make systemlib unique per prefix (#14887)
006b11df04 is described below
commit 006b11df046f3aecc076384fb96cb6733287b25f
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri May 19 08:04:44 2023 -0400
[RUNTIME] Make systemlib unique per prefix (#14887)
This PR enhances systemlib to make it unique per prefix.
This can help reduce flakiness when the same library is loaded multiple times.
---
src/runtime/system_library.cc | 53 +++++++++++++++++------
tests/python/unittest/test_target_codegen_blob.py | 2 +
2 files changed, 42 insertions(+), 13 deletions(-)
diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc
index be9257e53f..55335649a7 100644
--- a/src/runtime/system_library.cc
+++ b/src/runtime/system_library.cc
@@ -32,30 +32,30 @@
namespace tvm {
namespace runtime {
-class SystemLibraryRegistry {
+class SystemLibSymbolRegistry {
public:
void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
- auto it = tbl_.find(name);
- if (it != tbl_.end() && ptr != it->second) {
+ auto it = symbol_table_.find(name);
+ if (it != symbol_table_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a
different address " << ptr
<< "->" << it->second;
}
- tbl_[name] = ptr;
+ symbol_table_[name] = ptr;
}
void* GetSymbol(const char* name) {
std::lock_guard<std::mutex> lock(mutex_);
- auto it = tbl_.find(name);
- if (it != tbl_.end()) {
+ auto it = symbol_table_.find(name);
+ if (it != symbol_table_.end()) {
return it->second;
} else {
return nullptr;
}
}
- static SystemLibraryRegistry* Global() {
- static SystemLibraryRegistry* inst = new SystemLibraryRegistry();
+ static SystemLibSymbolRegistry* Global() {
+ static SystemLibSymbolRegistry* inst = new SystemLibSymbolRegistry();
return inst;
}
@@ -63,7 +63,7 @@ class SystemLibraryRegistry {
// Internal mutex
std::mutex mutex_;
// Internal symbol table
- std::unordered_map<std::string, void*> tbl_;
+ std::unordered_map<std::string, void*> symbol_table_;
};
class SystemLibrary : public Library {
@@ -80,22 +80,49 @@ class SystemLibrary : public Library {
}
private:
- SystemLibraryRegistry* reg_ = SystemLibraryRegistry::Global();
+ SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global();
std::string symbol_prefix_;
};
+class SystemLibModuleRegistry {
+ public:
+ runtime::Module GetOrCreateModule(std::string symbol_prefix) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto it = lib_map_.find(symbol_prefix);
+ if (it != lib_map_.end()) {
+ return it->second;
+ } else {
+ auto mod =
CreateModuleFromLibrary(make_object<SystemLibrary>(symbol_prefix));
+ lib_map_[symbol_prefix] = mod;
+ return mod;
+ }
+ }
+
+ static SystemLibModuleRegistry* Global() {
+ static SystemLibModuleRegistry* inst = new SystemLibModuleRegistry();
+ return inst;
+ }
+
+ private:
+ // Internal mutex
+ std::mutex mutex_;
+ // we need to make sure each lib map have an unique
+ // copy through out the entire lifetime of the process
+ // so the cached PackedFunc in the system do not get out dated.
+ std::unordered_map<std::string, runtime::Module> lib_map_;
+};
+
TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body([](TVMArgs args,
TVMRetValue* rv) {
std::string symbol_prefix = "";
if (args.size() != 0) {
symbol_prefix = args[0].operator std::string();
}
- auto mod =
CreateModuleFromLibrary(make_object<SystemLibrary>(symbol_prefix));
- *rv = mod;
+ *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix);
});
} // namespace runtime
} // namespace tvm
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
- tvm::runtime::SystemLibraryRegistry::Global()->RegisterSymbol(name, ptr);
+ tvm::runtime::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr);
return 0;
}
diff --git a/tests/python/unittest/test_target_codegen_blob.py
b/tests/python/unittest/test_target_codegen_blob.py
index d7683fd68c..2848c875a4 100644
--- a/tests/python/unittest/test_target_codegen_blob.py
+++ b/tests/python/unittest/test_target_codegen_blob.py
@@ -122,6 +122,8 @@ def test_cuda_multi_lib():
b_nd = tvm.nd.array(a_np, dev)
syslibA = tvm.runtime.system_lib("modA_")
syslibB = tvm.runtime.system_lib("modB_")
+ # reload same lib twice
+ syslibA = tvm.runtime.system_lib("modA_")
syslibA["my_inplace_update"](a_nd)
syslibB["my_inplace_update"](b_nd)
np.testing.assert_equal(a_nd.numpy(), a_np + 1)