This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 006b11df04 [RUNTIME] Make systemlib unique per prefix (#14887)
006b11df04 is described below

commit 006b11df046f3aecc076384fb96cb6733287b25f
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri May 19 08:04:44 2023 -0400

    [RUNTIME] Make systemlib unique per prefix (#14887)
    
    This PR makes the systemlib module unique per symbol prefix.
    This can help reduce flaky failures when the same lib is loaded multiple times.
---
 src/runtime/system_library.cc                     | 53 +++++++++++++++++------
 tests/python/unittest/test_target_codegen_blob.py |  2 +
 2 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc
index be9257e53f..55335649a7 100644
--- a/src/runtime/system_library.cc
+++ b/src/runtime/system_library.cc
@@ -32,30 +32,30 @@
 namespace tvm {
 namespace runtime {
 
-class SystemLibraryRegistry {
+class SystemLibSymbolRegistry {
  public:
   void RegisterSymbol(const std::string& name, void* ptr) {
     std::lock_guard<std::mutex> lock(mutex_);
-    auto it = tbl_.find(name);
-    if (it != tbl_.end() && ptr != it->second) {
+    auto it = symbol_table_.find(name);
+    if (it != symbol_table_.end() && ptr != it->second) {
       LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a different address " << ptr
                    << "->" << it->second;
     }
-    tbl_[name] = ptr;
+    symbol_table_[name] = ptr;
   }
 
   void* GetSymbol(const char* name) {
     std::lock_guard<std::mutex> lock(mutex_);
-    auto it = tbl_.find(name);
-    if (it != tbl_.end()) {
+    auto it = symbol_table_.find(name);
+    if (it != symbol_table_.end()) {
       return it->second;
     } else {
       return nullptr;
     }
   }
 
-  static SystemLibraryRegistry* Global() {
-    static SystemLibraryRegistry* inst = new SystemLibraryRegistry();
+  static SystemLibSymbolRegistry* Global() {
+    static SystemLibSymbolRegistry* inst = new SystemLibSymbolRegistry();
     return inst;
   }
 
@@ -63,7 +63,7 @@ class SystemLibraryRegistry {
   // Internal mutex
   std::mutex mutex_;
   // Internal symbol table
-  std::unordered_map<std::string, void*> tbl_;
+  std::unordered_map<std::string, void*> symbol_table_;
 };
 
 class SystemLibrary : public Library {
@@ -80,22 +80,49 @@ class SystemLibrary : public Library {
   }
 
  private:
-  SystemLibraryRegistry* reg_ = SystemLibraryRegistry::Global();
+  SystemLibSymbolRegistry* reg_ = SystemLibSymbolRegistry::Global();
   std::string symbol_prefix_;
 };
 
+class SystemLibModuleRegistry {
+ public:
+  runtime::Module GetOrCreateModule(std::string symbol_prefix) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = lib_map_.find(symbol_prefix);
+    if (it != lib_map_.end()) {
+      return it->second;
+    } else {
+      auto mod = CreateModuleFromLibrary(make_object<SystemLibrary>(symbol_prefix));
+      lib_map_[symbol_prefix] = mod;
+      return mod;
+    }
+  }
+
+  static SystemLibModuleRegistry* Global() {
+    static SystemLibModuleRegistry* inst = new SystemLibModuleRegistry();
+    return inst;
+  }
+
+ private:
+  // Internal mutex
+  std::mutex mutex_;
+  // We need to make sure each symbol prefix maps to a unique module
+  // copy throughout the entire lifetime of the process, so that
+  // PackedFuncs cached in the system do not become outdated.
+  std::unordered_map<std::string, runtime::Module> lib_map_;
+};
+
 TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body([](TVMArgs args, TVMRetValue* rv) {
   std::string symbol_prefix = "";
   if (args.size() != 0) {
     symbol_prefix = args[0].operator std::string();
   }
-  auto mod = CreateModuleFromLibrary(make_object<SystemLibrary>(symbol_prefix));
-  *rv = mod;
+  *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix);
 });
 }  // namespace runtime
 }  // namespace tvm
 
 int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
-  tvm::runtime::SystemLibraryRegistry::Global()->RegisterSymbol(name, ptr);
+  tvm::runtime::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr);
   return 0;
 }
diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py
index d7683fd68c..2848c875a4 100644
--- a/tests/python/unittest/test_target_codegen_blob.py
+++ b/tests/python/unittest/test_target_codegen_blob.py
@@ -122,6 +122,8 @@ def test_cuda_multi_lib():
         b_nd = tvm.nd.array(a_np, dev)
         syslibA = tvm.runtime.system_lib("modA_")
         syslibB = tvm.runtime.system_lib("modB_")
+        # load the same lib a second time
+        syslibA = tvm.runtime.system_lib("modA_")
         syslibA["my_inplace_update"](a_nd)
         syslibB["my_inplace_update"](b_nd)
         np.testing.assert_equal(a_nd.numpy(), a_np + 1)

Reply via email to