This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4eafd00cad [Relax][Bugfix] FCallPacked not checked in CodegenVMTIR (#17073)
4eafd00cad is described below
commit 4eafd00cada11a03c2a949cc6fd0e5d9a06e013b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 09:46:00 2024 -0500
[Relax][Bugfix] FCallPacked not checked in CodegenVMTIR (#17073)
Prior to this commit, an operator's `FCallPacked` attribute, used to
specify a 1:1 mapping between a relax operator and a `PackedFunc` that
implements it, was only checked in `CodegenVM`. Any operator with
`FCallPacked` would raise an error when compiled using `CodegenVMTIR`.
This commit removes the `FCallPacked` handling from `CodegenVM`
altogether, and instead checks for this attribute as part of
`LegalizeOps`. This provides the same functionality across both
backends.
---
src/relax/backend/vm/codegen_vm.cc | 24 +----
src/relax/backend/vm/codegen_vm_tir.cc | 24 +----
src/relax/transform/legalize_ops.cc | 25 ++++--
tests/python/relax/test_relax_operators.py | 139 +++++++++++++++++------------
4 files changed, 101 insertions(+), 111 deletions(-)
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index 1c79559462..ca2d4d4fdb 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -45,21 +45,6 @@ using namespace relax;
using namespace tvm::runtime;
using namespace tvm::runtime::relax_vm;
-namespace {
-// Helper function to get the function name of the registered packed function implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
- static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
- if (call->op.as<OpNode>()) {
- Op op = Downcast<Op>(call->op);
- if (op_map.count(op)) {
- return op_map[op];
- }
- }
- return {};
-}
-} // namespace
-
/*!
* \brief A class to generate VM executable for Relax functions.
*/
@@ -156,14 +141,7 @@ class CodeGenVM : public
ExprFunctor<Instruction::Arg(const Expr&)> {
// allocate dst register.
RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister :
NewRegister();
if (call->op.as<OpNode>()) {
- // special case generate for the intrinsics whose attribute fields
- // cannot be represented by args in the CallNode
- FCallPacked name = GetPackedFuncName(call);
- if (!name.empty()) {
- // If the operator has a registered packed function implementation, emit call to that packed
- // function.
- EmitPackedFuncCall(call, name, dst_reg);
- } else if (call_node->op == call_builtin_with_ctx_op_) {
+ if (call_node->op == call_builtin_with_ctx_op_) {
// TODO(relax-team) migrate most handling of op to
// directly map to call_builtin_with_ctx before codegen and simplify
vm codegen.
EmitCallBuiltinWithCtx(call, dst_reg);
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc
b/src/relax/backend/vm/codegen_vm_tir.cc
index 5e6a1c3f84..a92cf7c749 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -44,21 +44,6 @@ namespace relax_vm {
using vm::VMFuncInfo;
-namespace {
-// Helper function to get the function name of the registered packed function implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
- static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
- if (call->op.as<OpNode>()) {
- Op op = Downcast<Op>(call->op);
- if (op_map.count(op)) {
- return op_map[op];
- }
- }
- return {};
-}
-} // namespace
-
/*!
* \brief A class to generate VMTIR for Relax functions.
*
@@ -247,14 +232,7 @@ class CodeGenVMTIR : public
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
}
int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
if (call->op.as<OpNode>()) {
- // special case generate for the intrinsics whose attribute fields
- // cannot be represented by args in the CallNode
- FCallPacked name = GetPackedFuncName(call);
- if (name.size()) {
- // If the operator has a registered packed function implementation, emit call to that packed
- // function.
- EmitCallPacked(name, VisitArray(call->args), dst_reg);
- } else if (call_node->op == call_builtin_with_ctx_op_) {
+ if (call_node->op == call_builtin_with_ctx_op_) {
EmitCallBuiltinWithCtx(call, dst_reg);
} else if (call_node->op == alloc_storage_op_) {
EmitAllocStorage(call, dst_reg);
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index 34902fa0f8..4a6b44bf28 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -224,6 +224,7 @@ class LegalizeMutator : public ExprMutator {
Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+ static const auto& call_packed_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
static const auto& requires_arg_shapes_map =
Op::GetAttrMap<Bool>("RequiresArgumentShapes");
static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
static const Op& call_tir_op = Op::Get("relax.call_tir");
@@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator {
}
auto op = GetRef<Op>(op_node);
- bool can_legalize = [&]() -> bool {
+ bool shapes_are_known_if_required = [&]() -> bool {
bool requires_arg_shapes = requires_arg_shapes_map.get(op,
Bool(true))->value;
if (!requires_arg_shapes) {
// This operator does not require its arguments to have a
@@ -299,23 +300,31 @@ class LegalizeMutator : public ExprMutator {
return true;
}();
- if (!can_legalize) {
- return visited_call;
- }
-
FLegalize legalization_func;
- if (auto opt_custom_legalize = cmap_.Get(op->name)) {
+ if (auto opt_custom_legalize = cmap_.Get(op->name);
+ opt_custom_legalize && shapes_are_known_if_required) {
// First choice, use a custom legalization function
legalization_func = opt_custom_legalize.value();
- } else if (legalize_map.count(op)) {
+ } else if (legalize_map.count(op) && shapes_are_known_if_required) {
// Second choice, use a default legalization
legalization_func = legalize_map[op];
+ } else if (call_packed_map.count(op)) {
+ // Third choice, use an explicit FCallPacked replacement. This does not require the shape
+ String packed_func_name = call_packed_map[op];
+ legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr {
+ return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)});
+ };
} else {
// No legalization.
if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
op != call_pure_packed_op) {
- LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+ if (shapes_are_known_if_required) {
+ LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+ } else {
+ LOG(WARNING) << "Cannot legalize " << visited_call
+ << ", missing known shapes for arguments and return value";
+ }
}
return visited_call;
}
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index 41618a32cb..fcb8727d85 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -27,6 +27,8 @@ from tvm import relax
from tvm._ffi.base import TVMError
from tvm.script import ir as I, relax as R, tir as T
+exec_mode = tvm.testing.parameter("bytecode", "compiled")
+
@tvm.script.ir_module
class InputModule:
@@ -37,7 +39,7 @@ class InputModule:
return y, y_sorted
-def run_cpu(mod, func_name, *args):
+def run_cpu(mod, func_name, *args, exec_mode):
if isinstance(mod, relax.Function):
func = mod
args = [func_name, *args]
@@ -45,17 +47,17 @@ def run_cpu(mod, func_name, *args):
mod = tvm.IRModule.from_expr(func)
target = tvm.target.Target("llvm")
- ex = relax.build(mod, target)
+ ex = relax.build(mod, target, exec_mode=exec_mode)
vm = relax.VirtualMachine(ex, tvm.cpu())
return vm[func_name](*args)
-def test_unique():
+def test_unique(exec_mode):
# TODO(prakalp): also add test for compiling and running on cuda device.
data_numpy = np.random.randint(0, 16, (16, 16))
data = tvm.nd.array(data_numpy)
- result, result_sorted = run_cpu(InputModule, "foo", data)
+ result, result_sorted = run_cpu(InputModule, "foo", data, exec_mode=exec_mode)
expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
expected_output = [data_numpy.flatten()[index] for index in
sorted(indices, reverse=True)]
@@ -81,12 +83,17 @@ class PrintTest:
return x
-def test_print():
+def test_print(exec_mode):
try:
stdout = sys.stdout
with tempfile.TemporaryFile(mode="w+") as test_out:
sys.stdout = test_out
- run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(
+ PrintTest,
+ "foo",
+ tvm.nd.array(np.array(1).astype("int32")),
+ exec_mode=exec_mode,
+ )
test_out.seek(0)
printed_text = str(test_out.read())
expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1
1\nAnother print: 1 (1, 1)\n"
@@ -95,65 +102,65 @@ def test_print():
sys.stdout = stdout
-def test_assert_passes():
+def test_assert_passes(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(True))
return x
- run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")),
exec_mode=exec_mode)
-def test_assert_passes_with_format_args():
+def test_assert_passes_with_format_args(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(True), x, format="You won't see me")
return x
- run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")),
exec_mode=exec_mode)
-def test_assert_fails():
+def test_assert_fails(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(False))
return x
with pytest.raises(AssertionError, match="Assertion Failed"):
- run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")),
exec_mode=exec_mode)
-def test_assert_fails_with_message():
+def test_assert_fails_with_message(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(False), format="I failed...")
return x
with pytest.raises(AssertionError, match="I failed..."):
- run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")),
exec_mode=exec_mode)
-def test_assert_fails_with_args():
+def test_assert_fails_with_args(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(False), [x, x])
return x
with pytest.raises(AssertionError, match="5, 5"):
- run_cpu(func, tvm.nd.array(np.array(5).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(5).astype("int32")),
exec_mode=exec_mode)
-def test_assert_fails_with_formatted_args():
+def test_assert_fails_with_formatted_args(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor((), "int32")):
_ = R.assert_op(relax.const(False), x, format="Number: {}")
return x
with pytest.raises(AssertionError, match="Number: 6"):
- run_cpu(func, tvm.nd.array(np.array(6).astype("int32")))
+ run_cpu(func, tvm.nd.array(np.array(6).astype("int32")),
exec_mode=exec_mode)
-def test_assert_on_argument_passes():
+def test_assert_on_argument_passes(exec_mode):
@R.function(pure=False)
def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
_ = R.assert_op(condition)
@@ -161,10 +168,10 @@ def test_assert_on_argument_passes():
condition = tvm.nd.array(np.array(True))
x = tvm.nd.array(np.array(5).astype("int32"))
- run_cpu(func, condition, x)
+ run_cpu(func, condition, x, exec_mode=exec_mode)
-def test_assert_on_argument_fails():
+def test_assert_on_argument_fails(exec_mode):
@R.function(pure=False)
def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
_ = R.assert_op(condition)
@@ -173,10 +180,10 @@ def test_assert_on_argument_fails():
condition = tvm.nd.array(np.array(False))
x = tvm.nd.array(np.array(5).astype("int32"))
with pytest.raises(AssertionError):
- run_cpu(func, condition, x)
+ run_cpu(func, condition, x, exec_mode=exec_mode)
-def test_assert_on_symbolic_var_passes():
+def test_assert_on_symbolic_var_passes(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor(["N"], "int32")):
N = T.int64()
@@ -184,10 +191,10 @@ def test_assert_on_symbolic_var_passes():
return x
x = tvm.nd.array(np.arange(8, dtype="int32"))
- run_cpu(func, x)
+ run_cpu(func, x, exec_mode=exec_mode)
-def test_assert_on_symbolic_var_fails():
+def test_assert_on_symbolic_var_fails(exec_mode):
@R.function(pure=False)
def func(x: R.Tensor(["N"], "int32")):
N = T.int64()
@@ -196,7 +203,7 @@ def test_assert_on_symbolic_var_fails():
x = tvm.nd.array(np.arange(10, dtype="int32"))
with pytest.raises(AssertionError):
- run_cpu(func, x)
+ run_cpu(func, x, exec_mode=exec_mode)
@tvm.script.ir_module
@@ -223,23 +230,31 @@ class ShapeOfTest:
return R.shape_of(x)
-def test_op_shape_of():
- unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape")
+def test_op_shape_of(exec_mode):
+ unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape", exec_mode=exec_mode)
assert unit_shape == tvm.runtime.ShapeTuple([])
- const_shape = run_cpu(ShapeOfTest, "get_constant_shape")
+ const_shape = run_cpu(ShapeOfTest, "get_constant_shape",
exec_mode=exec_mode)
assert const_shape == tvm.runtime.ShapeTuple([2, 2])
- scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1,
dtype="int32")))
+ scalar_shape = run_cpu(
+ ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")),
exec_mode=exec_mode
+ )
assert scalar_shape == tvm.runtime.ShapeTuple([])
tensor_shape = run_cpu(
- ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2,
3)).astype("int32"))
+ ShapeOfTest,
+ "get_shape",
+ tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")),
+ exec_mode=exec_mode,
)
assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])
constrained_shape = run_cpu(
- ShapeOfTest, "get_constrained_shape",
tvm.nd.array(np.zeros((1,)).astype("int32"))
+ ShapeOfTest,
+ "get_constrained_shape",
+ tvm.nd.array(np.zeros((1,)).astype("int32")),
+ exec_mode=exec_mode,
)
assert constrained_shape == tvm.runtime.ShapeTuple([1])
@@ -257,7 +272,7 @@ class ShapeToTensorTest:
return R.shape_to_tensor(shape)
-def test_op_shape_to_tensor():
+def test_op_shape_to_tensor(exec_mode):
# Check struct info
isinstance(ShapeToTensorTest["const_shape"].body.struct_info,
tvm.relax.TensorStructInfo)
assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1
@@ -265,24 +280,32 @@ def test_op_shape_to_tensor():
assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1
# Check its functionality
- out2d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 2]))
+ out2d = run_cpu(
+ ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]),
exec_mode=exec_mode
+ )
assert isinstance(out2d, tvm.runtime.ndarray.NDArray)
assert np.array_equal(out2d.numpy(), np.array([3, 2]))
- out3d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 3, 2]))
+ out3d = run_cpu(
+ ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]),
exec_mode=exec_mode
+ )
assert isinstance(out3d, tvm.runtime.ndarray.NDArray)
assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))
- out4d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 3, 2, 2]))
+ out4d = run_cpu(
+ ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2,
2]), exec_mode=exec_mode
+ )
assert isinstance(out4d, tvm.runtime.ndarray.NDArray)
assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))
- outs = run_cpu(ShapeToTensorTest, "symbolic_shape",
tvm.runtime.ShapeTuple([3, 2]))
+ outs = run_cpu(
+ ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]),
exec_mode=exec_mode
+ )
assert isinstance(outs, tvm.runtime.ndarray.NDArray)
assert np.array_equal(outs.numpy(), np.array([3, 2]))
-def test_op_call_pure_packed():
+def test_op_call_pure_packed(exec_mode):
@tvm.script.ir_module
class CallPureTest:
@R.function
@@ -294,11 +317,11 @@ def test_op_call_pure_packed():
np.random.seed(0) # to avoid flakiness
arr = np.random.rand(3, 4).astype("float32")
- copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr))
+ copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr),
exec_mode=exec_mode)
assert (copy_found.numpy() == arr).all()
-def test_op_call_inplace_packed():
+def test_op_call_inplace_packed(exec_mode):
# in this case we can use the same test as above
@tvm.script.ir_module
class CallInplaceTest:
@@ -312,7 +335,7 @@ def test_op_call_inplace_packed():
)
return z
- @tvm.register_func("test.inplace.add")
+ @tvm.register_func("test.inplace.add", override=True)
def inplace_add(a, b):
arr_a = a.numpy()
arr_b = b.numpy()
@@ -340,11 +363,13 @@ def test_op_call_inplace_packed():
arr_b = np.random.rand(3, 4).astype("float32")
sum = arr_a + arr_b
tvm_arr_a = tvm.nd.array(arr_a)
- result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a,
tvm.nd.array(arr_b))
+ result = run_cpu(
+ CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b),
exec_mode=exec_mode
+ )
assert result == tvm_arr_a
assert (result.numpy() == sum).all()
- @tvm.register_func("test.inplace.tuple_add")
+ @tvm.register_func("test.inplace.tuple_add", override=True)
def inplace_tuple_add(a, b):
arr_a = a.numpy()
arr_b = b.numpy()
@@ -374,14 +399,14 @@ def test_op_call_inplace_packed():
sum = arr_a + arr_b
tvm_arr_a = tvm.nd.array(arr_a)
tvm_arr_b = tvm.nd.array(arr_b)
- result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+ result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b,
exec_mode=exec_mode)
assert result[0] == tvm_arr_a
assert (result[0].numpy() == sum).all()
assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
assert (result[1].numpy() == sum).all()
-def test_op_to_device():
+def test_op_to_device(exec_mode):
@tvm.script.ir_module
class CallToDevice:
@R.function
@@ -397,11 +422,11 @@ def test_op_to_device():
np.random.seed(0) # to avoid flakiness
arr = np.random.rand(3, 4).astype("float32")
- copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr))
+ copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr),
exec_mode=exec_mode)
assert (copy_found.numpy() == arr).all()
-def test_op_to_vdevice():
+def test_op_to_vdevice(exec_mode):
@tvm.script.ir_module
class ToVDevice:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@@ -414,11 +439,11 @@ def test_op_to_vdevice():
np.random.seed(0)
arr = np.random.rand(3, 4).astype("float32")
- copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr))
+ copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr),
exec_mode=exec_mode)
assert (copy_found.numpy() == arr).all()
-def test_scalar_tensor_as_branch_condition():
+def test_scalar_tensor_as_branch_condition(exec_mode):
"""The condition of a branch may be a scalar tensor"""
@R.function
@@ -429,14 +454,14 @@ def test_scalar_tensor_as_branch_condition():
out = R.prim_value(10)
return out
- res = run_cpu(func, tvm.nd.array(np.array(True)))
+ res = run_cpu(func, tvm.nd.array(np.array(True)), exec_mode=exec_mode)
assert res == 5
- res = run_cpu(func, tvm.nd.array(np.array(False)))
+ res = run_cpu(func, tvm.nd.array(np.array(False)), exec_mode=exec_mode)
assert res == 10
-def test_prim_value_as_branch_condition():
+def test_prim_value_as_branch_condition(exec_mode):
"""The condition may be a PrimValue"""
@R.function
@@ -447,14 +472,14 @@ def test_prim_value_as_branch_condition():
out = R.prim_value(10)
return out
- res = run_cpu(func, True)
+ res = run_cpu(func, True, exec_mode=exec_mode)
assert res == 5
- res = run_cpu(func, False)
+ res = run_cpu(func, False, exec_mode=exec_mode)
assert res == 10
-def test_computed_prim_value_as_branch_condition():
+def test_computed_prim_value_as_branch_condition(exec_mode):
"""The R.Prim condition may be computed within the function"""
@R.function
@@ -466,10 +491,10 @@ def test_computed_prim_value_as_branch_condition():
out = R.prim_value(10)
return out
- res = run_cpu(func, tvm.nd.array(np.arange(16)))
+ res = run_cpu(func, tvm.nd.array(np.arange(16)), exec_mode=exec_mode)
assert res == 5
- res = run_cpu(func, tvm.nd.array(np.arange(20)))
+ res = run_cpu(func, tvm.nd.array(np.arange(20)), exec_mode=exec_mode)
assert res == 10