This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cb39cc1a97 [Unity] Support multilib relax build (#14873)
cb39cc1a97 is described below
commit cb39cc1a97d4f1650795a77de3a96c9aa93769ad
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 18 03:46:08 2023 -0400
[Unity] Support multilib relax build (#14873)
* [Unity] AttachGlobalSymbol support multilib
This PR makes AttachGlobalSymbol aware of
system_lib_prefix and enables multi-lib support.
* Fix conflicting concurrent access to temp files
* Temporarily enable a legacy behavior that impacts MS relax integration
---
python/tvm/relax/vm_build.py | 1 +
src/relax/backend/vm/codegen_vm.cc | 6 +-
src/relax/backend/vm/codegen_vm_tir.cc | 8 +-
src/relax/transform/attach_global_symbol.cc | 48 +++++----
.../relax/test_transform_attach_global_symbol.py | 87 +++++++++++-----
tests/python/relax/test_vm_build.py | 110 +++++++++++++++++----
6 files changed, 185 insertions(+), 75 deletions(-)
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index a833939833..7e4fd85987 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -339,6 +339,7 @@ def build(
def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule:
tir_mod = IRModule({})
+ tir_mod = tir_mod.with_attrs(mod.attrs)
for gv in mod.get_global_vars():
if isinstance(mod[gv], PrimFunc):
tir_mod[gv] = mod[gv]
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index c44300907f..3fbe246cd3 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -67,14 +67,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const
Expr&)> {
: builder_(builder), ctx_mod_(ctx_mod) {}
static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
- IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
+ IRModule res_mod = mod;
+ res_mod.CopyOnWrite();
CodeGenVM codegen(builder, mod);
// Remove relax function and turn into TIR func.
for (const auto& [gvar, f] : mod->functions) {
if (auto* func = f.as<FunctionNode>()) {
codegen.Codegen(GetRef<Function>(func));
- } else {
- res_mod->Add(gvar, f);
+ res_mod->Remove(gvar);
}
}
return res_mod;
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc
b/src/relax/backend/vm/codegen_vm_tir.cc
index 276632a917..9ac65f6f6e 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -53,7 +53,9 @@ using vm::VMFuncInfo;
class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
public:
explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
- : builder_(builder), ctx_mod_(ctx_mod) {}
+ : builder_(builder), ctx_mod_(ctx_mod) {
+ system_lib_prefix_ =
ctx_mod_->GetAttr<String>(tvm::attr::kSystemLibPrefix);
+ }
static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
// create a new copy
@@ -189,7 +191,7 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
Type ret_type = VoidType();
Array<tir::Var> tir_params = {ctx_ptr_, reg_anylist_handle_,
const_anylist_handle_,
func_anylist_handle_};
- String tir_func_name = "__vmtir__" + gsymbol.value();
+ String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" +
gsymbol.value();
tir::PrimFunc tir_func(tir_params, body, ret_type, {});
tir_func = WithAttr(tir_func, "global_symbol", tir_func_name);
registers_num_ = 0;
@@ -506,6 +508,8 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
std::unordered_map<Var, Optional<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
var_map_;
/*! \brief the context module. */
IRModule ctx_mod_;
+ /*! \brief system lib prefix */
+ Optional<String> system_lib_prefix_;
/*! \brief Cache ops that need to be frequently used later to reduce lookup
overhead. */
const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
diff --git a/src/relax/transform/attach_global_symbol.cc
b/src/relax/transform/attach_global_symbol.cc
index be779e97bc..9b2a561c7f 100644
--- a/src/relax/transform/attach_global_symbol.cc
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -21,48 +21,44 @@
* \brief Attach global_symbol to Relax functions and TIR Primfuncs for
codegen.
*/
+#include <tvm/ir/module.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/function.h>
namespace tvm {
namespace relax {
+namespace transform {
+
+Pass AttachGlobalSymbol() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod,
+
PassContext pc) {
+ mod.CopyOnWrite();
-class GlobalSymbolAttacher {
- public:
- explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {}
+ String c_prefix =
mod->GetAttr<String>(tvm::attr::kSystemLibPrefix).value_or("");
+ std::vector<std::pair<GlobalVar, BaseFunc> > updates;
- IRModule Attach() {
- IRModule ret;
- for (auto& p : mod_->functions) {
+ for (auto& p : mod->functions) {
BaseFunc func = p.second;
+ // TODO(tvm-team): re-enable once fix relax integration part
+ // if (func->GetAttr<String>(tvm::attr::kGlobalSymbol)) continue;
if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
- func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
p.first->name_hint);
+ updates.emplace_back(p.first,
+ WithAttr(GetRef<tir::PrimFunc>(prim_func),
tvm::attr::kGlobalSymbol,
+ c_prefix + p.first->name_hint));
} else if (auto* relax_func = func.as<FunctionNode>()) {
- func = WithAttr(GetRef<Function>(relax_func), "global_symbol",
p.first->name_hint);
- } else {
- LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
- throw;
+ updates.emplace_back(p.first, WithAttr(GetRef<Function>(relax_func),
+ tvm::attr::kGlobalSymbol,
p.first->name_hint));
}
- ret->Add(p.first, func);
}
- return ret;
- }
-
- private:
- IRModule mod_;
-};
-
-namespace transform {
-
-Pass AttachGlobalSymbol() {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule mod, PassContext pc) { return
GlobalSymbolAttacher(mod).Attach(); };
+ for (const auto& pair : updates) {
+ mod->Add(pair.first, pair.second, true);
+ }
+ return mod;
+ };
return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {});
}
TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol);
-
} // namespace transform
-
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py
b/tests/python/relax/test_transform_attach_global_symbol.py
index 0937eed25a..035e21609d 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -21,35 +21,36 @@ from tvm import tir, relax
from tvm.ir import assert_structural_equal
import tvm.script
-from tvm.script import tir as T, relax as R
-
-
[email protected]_module
-class Before:
- @T.prim_func
- def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
- m = T.int64()
- n = T.int64()
- k = T.int64()
- A = T.match_buffer(x, (m, n))
- B = T.match_buffer(y, (n, k))
- C = T.match_buffer(z, (m, k))
-
- for i, j, k in T.grid(m, k, n):
- with T.block("matmul"):
- vi, vj, vk = T.axis.remap("SSR", [i, j, k])
- with T.init():
- C[vi, vj] = T.float32(0)
- C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-
- @R.function
- def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")) -> R.Tensor:
- m, n, k = T.int64(), T.int64(), T.int64()
- gv0 = R.call_tir(Before.tir_matmul, (x, w), R.Tensor((m, k),
dtype="float32"))
- return gv0
+from tvm.script import tir as T, relax as R, ir as I
def test_basic():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ m = T.int64()
+ n = T.int64()
+ k = T.int64()
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (n, k))
+ C = T.match_buffer(z, (m, k))
+
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+ def main(
+ x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")
+ ) -> R.Tensor:
+ m, n, k = T.int64(), T.int64(), T.int64()
+ gv0 = R.call_tir(Before.tir_matmul, (x, w), R.Tensor((m, k),
dtype="float32"))
+ return gv0
+
@tvm.script.ir_module
class Expected:
@T.prim_func
@@ -84,5 +85,39 @@ def test_basic():
assert_structural_equal(after, expected)
+def test_system_lib_prefix():
+ @tvm.script.ir_module
+ class Before:
+ I.module_attrs({"system_lib_prefix": "hello_"})
+
+ @T.prim_func
+ def tir_zeros(x: T.Buffer((2), "float32")) -> None:
+ x[0] = T.float32(0)
+
+ @R.function
+ def main() -> R.Tensor:
+ gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,),
dtype="float32"))
+ return gv0
+
+ @tvm.script.ir_module
+ class Expected:
+ I.module_attrs({"system_lib_prefix": "hello_"})
+
+ @T.prim_func
+ def tir_zeros(x: T.Buffer((2), "float32")) -> None:
+ T.func_attr({"global_symbol": "hello_tir_zeros"})
+ x[0] = T.float32(0)
+
+ @R.function
+ def main() -> R.Tensor:
+ R.func_attr({"global_symbol": "main"})
+ gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,),
dtype="float32"))
+ return gv0
+
+ before = Before
+ after = relax.transform.AttachGlobalSymbol()(before)
+ assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relax/test_vm_build.py
b/tests/python/relax/test_vm_build.py
index 9cf5445156..704fb923ad 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -15,19 +15,19 @@
# specific language governing permissions and limitations
# under the License.
import os
+import ctypes
from typing import Tuple, Callable
-import sys
-import tempfile
+
import numpy as np
import pytest
import tvm
import tvm.script
import tvm.testing
from tvm import relax, rpc, te, tir, topi
-from tvm.contrib import utils
+from tvm.contrib import utils, cc, popen_pool
from tvm.relax.testing import nn
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
from tvm.relax.testing.vm import check_saved_func
EXEC_MODE = ["bytecode", "compiled"]
@@ -724,6 +724,72 @@ class TestVMSetInput:
return gv0
[email protected]("exec_mode", EXEC_MODE)
+def test_multi_systemlib(exec_mode):
+ @tvm.script.ir_module
+ class ModA:
+ I.module_attrs({"system_lib_prefix": "libA_"})
+
+ @T.prim_func
+ def tir_init(x: T.Buffer((2), "float32")) -> None:
+ for i in range(2):
+ x[i] = T.float32(0)
+
+ @R.function
+ def main(s: R.Shape(["m"])) -> R.Tensor:
+ m = T.int64()
+ gv0 = R.call_tir(ModA.tir_init, (), R.Tensor((m + 1,),
dtype="float32"))
+ return gv0
+
+ @tvm.script.ir_module
+ class ModB:
+ I.module_attrs({"system_lib_prefix": "libB_"})
+
+ @T.prim_func
+ def tir_init(x: T.Buffer((2), "float32")) -> None:
+ for i in range(2):
+ x[i] = T.float32(1)
+
+ @R.function
+ def main(s: R.Shape(["m"])) -> R.Tensor:
+ m = T.int64()
+ gv0 = R.call_tir(ModB.tir_init, (), R.Tensor((m,),
dtype="float32"))
+ return gv0
+
+ target = tvm.target.Target("llvm", host="llvm")
+ libA = relax.build(ModA, target, exec_mode=exec_mode)
+ libB = relax.build(ModB, target, exec_mode=exec_mode)
+
+ temp = utils.tempdir()
+ pathA = temp.relpath("libA.a")
+ pathB = temp.relpath("libB.a")
+ path_dso = temp.relpath("mylibAll.so")
+ libA.export_library(pathA, cc.create_staticlib)
+ libB.export_library(pathB, cc.create_staticlib)
+
+ # package two static libs together
+ # check that they do not interfere with each other
+ # even though they have shared global var names
+ # intentionally craft same gvar function with different behaviors
+ cc.create_shared(path_dso, ["-Wl,--whole-archive", pathA, pathB,
"-Wl,--no-whole-archive"])
+
+ def popen_check():
+ # Load dll, will trigger system library registration
+ ctypes.CDLL(path_dso)
+ # Load the system wide library
+ vmA = relax.VirtualMachine(tvm.runtime.system_lib("libA_"), tvm.cpu())
+ vmB = relax.VirtualMachine(tvm.runtime.system_lib("libB_"), tvm.cpu())
+
+ retA = vmA["main"](tvm.runtime.ShapeTuple([1]))
+ retB = vmB["main"](tvm.runtime.ShapeTuple([2]))
+ np.testing.assert_equal(retA.numpy(), np.array([0,
0]).astype("float32"))
+ np.testing.assert_equal(retB.numpy(), np.array([1,
1]).astype("float32"))
+
+ # system lib should be loaded in different process
+ worker = popen_pool.PopenWorker()
+ worker.send(popen_check)
+
+
def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) ->
None:
a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
@@ -779,13 +845,13 @@ def set_input_attempt_get(vm: relax.VirtualMachine,
device: tvm.runtime.Device)
_ = vm.get_outputs("main")
-def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]:
+def make_vm(mod, exec_mode, temp) -> Tuple[relax.VirtualMachine,
tvm.runtime.Device]:
"""Returns a local VM for the given mod and the device"""
target = tvm.target.Target("llvm", host="llvm")
exec = relax.build(mod, target, exec_mode=exec_mode)
- exec.export_library("exec.so")
- exec_loaded = tvm.runtime.load_module("exec.so")
- os.remove("exec.so")
+ libname = temp.relpath("exec.so")
+ exec.export_library(libname)
+ exec_loaded = tvm.runtime.load_module(libname)
device = tvm.cpu()
return relax.VirtualMachine(exec_loaded, device), device
@@ -827,19 +893,21 @@ def run_on_rpc(
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_set_input(exec_mode):
- set_input_trial(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ set_input_trial(*make_vm(TestVMSetInput, exec_mode, temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_set_input_tuple(exec_mode):
@tvm.script.ir_module
- class Module:
+ class MyMod:
@R.function
def main(x: R.Tuple([R.Tensor((32,), "float32"), R.Tensor((32,),
"float32")])) -> R.Tensor:
y = x[0]
return y
- vm, device = make_vm(Module, exec_mode)
+ temp = utils.tempdir()
+ vm, device = make_vm(MyMod, exec_mode, temp)
device = tvm.cpu(0)
a = tvm.nd.empty((32,), "float32", device=device)
b = tvm.nd.empty((32,), "float32", device=device)
@@ -858,7 +926,8 @@ def save_function_kwargs_trial(vm: relax.VirtualMachine,
device: tvm.runtime.Dev
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_save_function_kwargs(exec_mode):
- save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode, temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -878,7 +947,8 @@ def save_function_time_evaluator_trial(
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_save_function_time_evaluator(exec_mode):
- save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode,
temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -890,7 +960,8 @@ def test_save_function_time_evaluator(exec_mode):
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@pytest.mark.xfail()
def test_set_input_stateless_failure(exec_mode):
- set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode, temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -902,19 +973,22 @@ def test_set_input_stateless_failure_rpc(exec_mode):
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@pytest.mark.xfail()
def test_set_input_invoke_failure(exec_mode):
- set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode, temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@pytest.mark.xfail()
def test_set_input_invoke_failure_rpc(exec_mode):
- run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode)
+ temp = utils.tempdir()
+ run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode, temp)
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@pytest.mark.xfail()
def test_set_input_get_failure(exec_mode):
- set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode))
+ temp = utils.tempdir()
+ set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode, temp))
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
@@ -924,4 +998,4 @@ def test_set_input_get_failure_rpc(exec_mode):
if __name__ == "__main__":
- tvm.testing.main()
+ pytest.main()