This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cb39cc1a97 [Unity] Support multilib relax build (#14873)
cb39cc1a97 is described below

commit cb39cc1a97d4f1650795a77de3a96c9aa93769ad
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 18 03:46:08 2023 -0400

    [Unity] Support multilib relax build (#14873)
    
    * [Unity] AttachGlobalSymbol support multilib
    
    This PR makes AttachGlobalSymbol aware of
    system_lib_prefix and enables multi-lib support.
    
    * Fix conflicting temp concurrent access
    
    * Temp enable a legacy behavior that impacts MS relax integration
---
 python/tvm/relax/vm_build.py                       |   1 +
 src/relax/backend/vm/codegen_vm.cc                 |   6 +-
 src/relax/backend/vm/codegen_vm_tir.cc             |   8 +-
 src/relax/transform/attach_global_symbol.cc        |  48 +++++----
 .../relax/test_transform_attach_global_symbol.py   |  87 +++++++++++-----
 tests/python/relax/test_vm_build.py                | 110 +++++++++++++++++----
 6 files changed, 185 insertions(+), 75 deletions(-)

diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index a833939833..7e4fd85987 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -339,6 +339,7 @@ def build(
 
 def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule:
     tir_mod = IRModule({})
+    tir_mod = tir_mod.with_attrs(mod.attrs)
     for gv in mod.get_global_vars():
         if isinstance(mod[gv], PrimFunc):
             tir_mod[gv] = mod[gv]
diff --git a/src/relax/backend/vm/codegen_vm.cc 
b/src/relax/backend/vm/codegen_vm.cc
index c44300907f..3fbe246cd3 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -67,14 +67,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
       : builder_(builder), ctx_mod_(ctx_mod) {}
 
   static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
-    IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
+    IRModule res_mod = mod;
+    res_mod.CopyOnWrite();
     CodeGenVM codegen(builder, mod);
     // Remove relax function and turn into TIR func.
     for (const auto& [gvar, f] : mod->functions) {
       if (auto* func = f.as<FunctionNode>()) {
         codegen.Codegen(GetRef<Function>(func));
-      } else {
-        res_mod->Add(gvar, f);
+        res_mod->Remove(gvar);
       }
     }
     return res_mod;
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc 
b/src/relax/backend/vm/codegen_vm_tir.cc
index 276632a917..9ac65f6f6e 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -53,7 +53,9 @@ using vm::VMFuncInfo;
 class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
  public:
   explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
-      : builder_(builder), ctx_mod_(ctx_mod) {}
+      : builder_(builder), ctx_mod_(ctx_mod) {
+    system_lib_prefix_ = 
ctx_mod_->GetAttr<String>(tvm::attr::kSystemLibPrefix);
+  }
 
   static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
     // create a new copy
@@ -189,7 +191,7 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
     Type ret_type = VoidType();
     Array<tir::Var> tir_params = {ctx_ptr_, reg_anylist_handle_, 
const_anylist_handle_,
                                   func_anylist_handle_};
-    String tir_func_name = "__vmtir__" + gsymbol.value();
+    String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + 
gsymbol.value();
     tir::PrimFunc tir_func(tir_params, body, ret_type, {});
     tir_func = WithAttr(tir_func, "global_symbol", tir_func_name);
     registers_num_ = 0;
@@ -506,6 +508,8 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
   std::unordered_map<Var, Optional<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> 
var_map_;
   /*! \brief the context module. */
   IRModule ctx_mod_;
+  /*! \brief system lib prefix */
+  Optional<String> system_lib_prefix_;
   /*! \brief Cache ops that need to be frequently used later to reduce lookup 
overhead. */
   const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
   const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
diff --git a/src/relax/transform/attach_global_symbol.cc 
b/src/relax/transform/attach_global_symbol.cc
index be779e97bc..9b2a561c7f 100644
--- a/src/relax/transform/attach_global_symbol.cc
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -21,48 +21,44 @@
  * \brief Attach global_symbol to Relax functions and TIR Primfuncs for 
codegen.
  */
 
+#include <tvm/ir/module.h>
 #include <tvm/relax/transform.h>
 #include <tvm/tir/function.h>
 
 namespace tvm {
 namespace relax {
+namespace transform {
+
+Pass AttachGlobalSymbol() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule mod,
+                                                                            
PassContext pc) {
+    mod.CopyOnWrite();
 
-class GlobalSymbolAttacher {
- public:
-  explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {}
+    String c_prefix = 
mod->GetAttr<String>(tvm::attr::kSystemLibPrefix).value_or("");
+    std::vector<std::pair<GlobalVar, BaseFunc> > updates;
 
-  IRModule Attach() {
-    IRModule ret;
-    for (auto& p : mod_->functions) {
+    for (auto& p : mod->functions) {
       BaseFunc func = p.second;
+      // TODO(tvm-team): re-enable once fix relax integration part
+      // if (func->GetAttr<String>(tvm::attr::kGlobalSymbol)) continue;
       if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
-        func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol", 
p.first->name_hint);
+        updates.emplace_back(p.first,
+                             WithAttr(GetRef<tir::PrimFunc>(prim_func), 
tvm::attr::kGlobalSymbol,
+                                      c_prefix + p.first->name_hint));
       } else if (auto* relax_func = func.as<FunctionNode>()) {
-        func = WithAttr(GetRef<Function>(relax_func), "global_symbol", 
p.first->name_hint);
-      } else {
-        LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
-        throw;
+        updates.emplace_back(p.first, WithAttr(GetRef<Function>(relax_func),
+                                               tvm::attr::kGlobalSymbol, 
p.first->name_hint));
       }
-      ret->Add(p.first, func);
     }
-    return ret;
-  }
-
- private:
-  IRModule mod_;
-};
-
-namespace transform {
-
-Pass AttachGlobalSymbol() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule mod, PassContext pc) { return 
GlobalSymbolAttacher(mod).Attach(); };
+    for (const auto& pair : updates) {
+      mod->Add(pair.first, pair.second, true);
+    }
+    return mod;
+  };
   return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {});
 }
 
 
TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol);
-
 }  // namespace transform
-
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py 
b/tests/python/relax/test_transform_attach_global_symbol.py
index 0937eed25a..035e21609d 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -21,35 +21,36 @@ from tvm import tir, relax
 from tvm.ir import assert_structural_equal
 
 import tvm.script
-from tvm.script import tir as T, relax as R
-
-
[email protected]_module
-class Before:
-    @T.prim_func
-    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
-        m = T.int64()
-        n = T.int64()
-        k = T.int64()
-        A = T.match_buffer(x, (m, n))
-        B = T.match_buffer(y, (n, k))
-        C = T.match_buffer(z, (m, k))
-
-        for i, j, k in T.grid(m, k, n):
-            with T.block("matmul"):
-                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
-                with T.init():
-                    C[vi, vj] = T.float32(0)
-                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-
-    @R.function
-    def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")) -> R.Tensor:
-        m, n, k = T.int64(), T.int64(), T.int64()
-        gv0 = R.call_tir(Before.tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
-        return gv0
+from tvm.script import tir as T, relax as R, ir as I
 
 
 def test_basic():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
+            A = T.match_buffer(x, (m, n))
+            B = T.match_buffer(y, (n, k))
+            C = T.match_buffer(z, (m, k))
+
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def main(
+            x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")
+        ) -> R.Tensor:
+            m, n, k = T.int64(), T.int64(), T.int64()
+            gv0 = R.call_tir(Before.tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
+            return gv0
+
     @tvm.script.ir_module
     class Expected:
         @T.prim_func
@@ -84,5 +85,39 @@ def test_basic():
     assert_structural_equal(after, expected)
 
 
+def test_system_lib_prefix():
+    @tvm.script.ir_module
+    class Before:
+        I.module_attrs({"system_lib_prefix": "hello_"})
+
+        @T.prim_func
+        def tir_zeros(x: T.Buffer((2), "float32")) -> None:
+            x[0] = T.float32(0)
+
+        @R.function
+        def main() -> R.Tensor:
+            gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,), 
dtype="float32"))
+            return gv0
+
+    @tvm.script.ir_module
+    class Expected:
+        I.module_attrs({"system_lib_prefix": "hello_"})
+
+        @T.prim_func
+        def tir_zeros(x: T.Buffer((2), "float32")) -> None:
+            T.func_attr({"global_symbol": "hello_tir_zeros"})
+            x[0] = T.float32(0)
+
+        @R.function
+        def main() -> R.Tensor:
+            R.func_attr({"global_symbol": "main"})
+            gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), 
dtype="float32"))
+            return gv0
+
+    before = Before
+    after = relax.transform.AttachGlobalSymbol()(before)
+    assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relax/test_vm_build.py 
b/tests/python/relax/test_vm_build.py
index 9cf5445156..704fb923ad 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -15,19 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
+import ctypes
 from typing import Tuple, Callable
 
-import sys
-import tempfile
+
 import numpy as np
 import pytest
 import tvm
 import tvm.script
 import tvm.testing
 from tvm import relax, rpc, te, tir, topi
-from tvm.contrib import utils
+from tvm.contrib import utils, cc, popen_pool
 from tvm.relax.testing import nn
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 from tvm.relax.testing.vm import check_saved_func
 
 EXEC_MODE = ["bytecode", "compiled"]
@@ -724,6 +724,72 @@ class TestVMSetInput:
         return gv0
 
 
[email protected]("exec_mode", EXEC_MODE)
+def test_multi_systemlib(exec_mode):
+    @tvm.script.ir_module
+    class ModA:
+        I.module_attrs({"system_lib_prefix": "libA_"})
+
+        @T.prim_func
+        def tir_init(x: T.Buffer((2), "float32")) -> None:
+            for i in range(2):
+                x[i] = T.float32(0)
+
+        @R.function
+        def main(s: R.Shape(["m"])) -> R.Tensor:
+            m = T.int64()
+            gv0 = R.call_tir(ModA.tir_init, (), R.Tensor((m + 1,), 
dtype="float32"))
+            return gv0
+
+    @tvm.script.ir_module
+    class ModB:
+        I.module_attrs({"system_lib_prefix": "libB_"})
+
+        @T.prim_func
+        def tir_init(x: T.Buffer((2), "float32")) -> None:
+            for i in range(2):
+                x[i] = T.float32(1)
+
+        @R.function
+        def main(s: R.Shape(["m"])) -> R.Tensor:
+            m = T.int64()
+            gv0 = R.call_tir(ModB.tir_init, (), R.Tensor((m,), 
dtype="float32"))
+            return gv0
+
+    target = tvm.target.Target("llvm", host="llvm")
+    libA = relax.build(ModA, target, exec_mode=exec_mode)
+    libB = relax.build(ModB, target, exec_mode=exec_mode)
+
+    temp = utils.tempdir()
+    pathA = temp.relpath("libA.a")
+    pathB = temp.relpath("libB.a")
+    path_dso = temp.relpath("mylibAll.so")
+    libA.export_library(pathA, cc.create_staticlib)
+    libB.export_library(pathB, cc.create_staticlib)
+
+    # package two static libs together
+    # check that they do not interfere with each other
+    # even though they have shared global var names
+    # intentionally craft same gvar function with different behaviors
+    cc.create_shared(path_dso, ["-Wl,--whole-archive", pathA, pathB, 
"-Wl,--no-whole-archive"])
+
+    def popen_check():
+        # Load dll, will trigger system library registration
+        ctypes.CDLL(path_dso)
+        # Load the system wide library
+        vmA = relax.VirtualMachine(tvm.runtime.system_lib("libA_"), tvm.cpu())
+        vmB = relax.VirtualMachine(tvm.runtime.system_lib("libB_"), tvm.cpu())
+
+        retA = vmA["main"](tvm.runtime.ShapeTuple([1]))
+        retB = vmB["main"](tvm.runtime.ShapeTuple([2]))
+        np.testing.assert_equal(retA.numpy(), np.array([0, 
0]).astype("float32"))
+        np.testing.assert_equal(retB.numpy(), np.array([1, 
1]).astype("float32"))
+
+    # system lib should be loaded in different process
+    worker = popen_pool.PopenWorker()
+    worker.send(popen_check)
+
+
 def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> 
None:
     a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
     b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
@@ -779,13 +845,13 @@ def set_input_attempt_get(vm: relax.VirtualMachine, 
device: tvm.runtime.Device)
     _ = vm.get_outputs("main")
 
 
-def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]:
+def make_vm(mod, exec_mode, temp) -> Tuple[relax.VirtualMachine, 
tvm.runtime.Device]:
     """Returns a local VM for the given mod and the device"""
     target = tvm.target.Target("llvm", host="llvm")
     exec = relax.build(mod, target, exec_mode=exec_mode)
-    exec.export_library("exec.so")
-    exec_loaded = tvm.runtime.load_module("exec.so")
-    os.remove("exec.so")
+    libname = temp.relpath("exec.so")
+    exec.export_library(libname)
+    exec_loaded = tvm.runtime.load_module(libname)
     device = tvm.cpu()
     return relax.VirtualMachine(exec_loaded, device), device
 
@@ -827,19 +893,21 @@ def run_on_rpc(
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_set_input(exec_mode):
-    set_input_trial(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    set_input_trial(*make_vm(TestVMSetInput, exec_mode, temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_set_input_tuple(exec_mode):
     @tvm.script.ir_module
-    class Module:
+    class MyMod:
         @R.function
         def main(x: R.Tuple([R.Tensor((32,), "float32"), R.Tensor((32,), 
"float32")])) -> R.Tensor:
             y = x[0]
             return y
 
-    vm, device = make_vm(Module, exec_mode)
+    temp = utils.tempdir()
+    vm, device = make_vm(MyMod, exec_mode, temp)
     device = tvm.cpu(0)
     a = tvm.nd.empty((32,), "float32", device=device)
     b = tvm.nd.empty((32,), "float32", device=device)
@@ -858,7 +926,8 @@ def save_function_kwargs_trial(vm: relax.VirtualMachine, 
device: tvm.runtime.Dev
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_save_function_kwargs(exec_mode):
-    save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode, temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -878,7 +947,8 @@ def save_function_time_evaluator_trial(
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_save_function_time_evaluator(exec_mode):
-    save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode, 
temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -890,7 +960,8 @@ def test_save_function_time_evaluator(exec_mode):
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 @pytest.mark.xfail()
 def test_set_input_stateless_failure(exec_mode):
-    set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode, temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -902,19 +973,22 @@ def test_set_input_stateless_failure_rpc(exec_mode):
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 @pytest.mark.xfail()
 def test_set_input_invoke_failure(exec_mode):
-    set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode, temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 @pytest.mark.xfail()
 def test_set_input_invoke_failure_rpc(exec_mode):
-    run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode)
+    temp = utils.tempdir()
+    run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode, temp)
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 @pytest.mark.xfail()
 def test_set_input_get_failure(exec_mode):
-    set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode))
+    temp = utils.tempdir()
+    set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode, temp))
 
 
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -924,4 +998,4 @@ def test_set_input_get_failure_rpc(exec_mode):
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    pytest.main()

Reply via email to