This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4eafd00cad [Relax][Bugfix] FCallPacked not checked in CodegenVMTIR (#17073)
4eafd00cad is described below

commit 4eafd00cada11a03c2a949cc6fd0e5d9a06e013b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 09:46:00 2024 -0500

    [Relax][Bugfix] FCallPacked not checked in CodegenVMTIR (#17073)
    
    Prior to this commit, an operator's `FCallPacked` attribute, used to
    specify a 1:1 mapping between a relax operator and a `PackedFunc` that
    implements it, was only checked in `CodegenVM`.  Any operator with
    `FCallPacked` would raise an error when compiled using `CodegenVMTIR`.
    
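    This commit removes the `FCallPacked` handling from both `CodegenVM`
    and `CodegenVMTIR`, and instead checks for this attribute as part of
    `LegalizeOps`.  This provides the same functionality across both
    backends.  Unlike `FLegalize`, the `FCallPacked` replacement is
    applied even when the shapes of the arguments are not known.
    
    The tests in `test_relax_operators.py` are now parametrized to run
    with both the `bytecode` and `compiled` execution modes.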
---
 src/relax/backend/vm/codegen_vm.cc         |  24 +----
 src/relax/backend/vm/codegen_vm_tir.cc     |  24 +----
 src/relax/transform/legalize_ops.cc        |  25 ++++--
 tests/python/relax/test_relax_operators.py | 139 +++++++++++++++++------------
 4 files changed, 101 insertions(+), 111 deletions(-)

diff --git a/src/relax/backend/vm/codegen_vm.cc 
b/src/relax/backend/vm/codegen_vm.cc
index 1c79559462..ca2d4d4fdb 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -45,21 +45,6 @@ using namespace relax;
 using namespace tvm::runtime;
 using namespace tvm::runtime::relax_vm;
 
-namespace {
-// Helper function to get the function name of the registered packed function 
implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
-  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
-  if (call->op.as<OpNode>()) {
-    Op op = Downcast<Op>(call->op);
-    if (op_map.count(op)) {
-      return op_map[op];
-    }
-  }
-  return {};
-}
-}  // namespace
-
 /*!
  * \brief A class to generate VM executable for Relax functions.
  */
@@ -156,14 +141,7 @@ class CodeGenVM : public 
ExprFunctor<Instruction::Arg(const Expr&)> {
     // allocate dst register.
     RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : 
NewRegister();
     if (call->op.as<OpNode>()) {
-      // special case generate for the intrinsics whose attribute fields
-      // cannot be represented by args in the CallNode
-      FCallPacked name = GetPackedFuncName(call);
-      if (!name.empty()) {
-        // If the operator has a registered packed function implementation, 
emit call to that packed
-        // function.
-        EmitPackedFuncCall(call, name, dst_reg);
-      } else if (call_node->op == call_builtin_with_ctx_op_) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
         // TODO(relax-team) migrate most handling of op to
         // directly map to call_builtin_with_ctx before codegen and simplify 
vm codegen.
         EmitCallBuiltinWithCtx(call, dst_reg);
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc 
b/src/relax/backend/vm/codegen_vm_tir.cc
index 5e6a1c3f84..a92cf7c749 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -44,21 +44,6 @@ namespace relax_vm {
 
 using vm::VMFuncInfo;
 
-namespace {
-// Helper function to get the function name of the registered packed function 
implementation of
-// relax operator.
-FCallPacked GetPackedFuncName(const Call& call) {
-  static auto op_map = Op::GetAttrMap<FCallPacked>("FCallPacked");
-  if (call->op.as<OpNode>()) {
-    Op op = Downcast<Op>(call->op);
-    if (op_map.count(op)) {
-      return op_map[op];
-    }
-  }
-  return {};
-}
-}  // namespace
-
 /*!
  * \brief A class to generate VMTIR for Relax functions.
  *
@@ -247,14 +232,7 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
     }
     int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
     if (call->op.as<OpNode>()) {
-      // special case generate for the intrinsics whose attribute fields
-      // cannot be represented by args in the CallNode
-      FCallPacked name = GetPackedFuncName(call);
-      if (name.size()) {
-        // If the operator has a registered packed function implementation, 
emit call to that packed
-        // function.
-        EmitCallPacked(name, VisitArray(call->args), dst_reg);
-      } else if (call_node->op == call_builtin_with_ctx_op_) {
+      if (call_node->op == call_builtin_with_ctx_op_) {
         EmitCallBuiltinWithCtx(call, dst_reg);
       } else if (call_node->op == alloc_storage_op_) {
         EmitAllocStorage(call, dst_reg);
diff --git a/src/relax/transform/legalize_ops.cc 
b/src/relax/transform/legalize_ops.cc
index 34902fa0f8..4a6b44bf28 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -224,6 +224,7 @@ class LegalizeMutator : public ExprMutator {
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+    static const auto& call_packed_map = 
Op::GetAttrMap<FCallPacked>("FCallPacked");
     static const auto& requires_arg_shapes_map = 
Op::GetAttrMap<Bool>("RequiresArgumentShapes");
     static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
     static const Op& call_tir_op = Op::Get("relax.call_tir");
@@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator {
     }
     auto op = GetRef<Op>(op_node);
 
-    bool can_legalize = [&]() -> bool {
+    bool shapes_are_known_if_required = [&]() -> bool {
       bool requires_arg_shapes = requires_arg_shapes_map.get(op, 
Bool(true))->value;
       if (!requires_arg_shapes) {
         // This operator does not require its arguments to have a
@@ -299,23 +300,31 @@ class LegalizeMutator : public ExprMutator {
       return true;
     }();
 
-    if (!can_legalize) {
-      return visited_call;
-    }
-
     FLegalize legalization_func;
 
-    if (auto opt_custom_legalize = cmap_.Get(op->name)) {
+    if (auto opt_custom_legalize = cmap_.Get(op->name);
+        opt_custom_legalize && shapes_are_known_if_required) {
       // First choice, use a custom legalization function
       legalization_func = opt_custom_legalize.value();
-    } else if (legalize_map.count(op)) {
+    } else if (legalize_map.count(op) && shapes_are_known_if_required) {
       // Second choice, use a default legalization
       legalization_func = legalize_map[op];
+    } else if (call_packed_map.count(op)) {
+      // Third choice, use an explicit FCallPacked replacement.  This does not 
require the shape
+      String packed_func_name = call_packed_map[op];
+      legalization_func = [packed_func_name](const BlockBuilder& bb, const 
Call& call) -> Expr {
+        return Call(ExternFunc(packed_func_name), call->args, Attrs(), 
{GetStructInfo(call)});
+      };
     } else {
       // No legalization.
       if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
           op != call_pure_packed_op) {
-        LOG(WARNING) << "No legalization func for " << op->name << " is 
found.";
+        if (shapes_are_known_if_required) {
+          LOG(WARNING) << "No legalization func for " << op->name << " is 
found.";
+        } else {
+          LOG(WARNING) << "Cannot legalize " << visited_call
+                       << ", missing known shapes for arguments and return 
value";
+        }
       }
       return visited_call;
     }
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index 41618a32cb..fcb8727d85 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -27,6 +27,8 @@ from tvm import relax
 from tvm._ffi.base import TVMError
 from tvm.script import ir as I, relax as R, tir as T
 
+exec_mode = tvm.testing.parameter("bytecode", "compiled")
+
 
 @tvm.script.ir_module
 class InputModule:
@@ -37,7 +39,7 @@ class InputModule:
         return y, y_sorted
 
 
-def run_cpu(mod, func_name, *args):
+def run_cpu(mod, func_name, *args, exec_mode):
     if isinstance(mod, relax.Function):
         func = mod
         args = [func_name, *args]
@@ -45,17 +47,17 @@ def run_cpu(mod, func_name, *args):
         mod = tvm.IRModule.from_expr(func)
 
     target = tvm.target.Target("llvm")
-    ex = relax.build(mod, target)
+    ex = relax.build(mod, target, exec_mode=exec_mode)
     vm = relax.VirtualMachine(ex, tvm.cpu())
 
     return vm[func_name](*args)
 
 
-def test_unique():
+def test_unique(exec_mode):
     # TODO(prakalp): also add test for compiling and running on cuda device.
     data_numpy = np.random.randint(0, 16, (16, 16))
     data = tvm.nd.array(data_numpy)
-    result, result_sorted = run_cpu(InputModule, "foo", data)
+    result, result_sorted = run_cpu(InputModule, "foo", data, 
exec_mode=exec_mode)
 
     expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
     expected_output = [data_numpy.flatten()[index] for index in 
sorted(indices, reverse=True)]
@@ -81,12 +83,17 @@ class PrintTest:
         return x
 
 
-def test_print():
+def test_print(exec_mode):
     try:
         stdout = sys.stdout
         with tempfile.TemporaryFile(mode="w+") as test_out:
             sys.stdout = test_out
-            run_cpu(PrintTest, "foo", 
tvm.nd.array(np.array(1).astype("int32")))
+            run_cpu(
+                PrintTest,
+                "foo",
+                tvm.nd.array(np.array(1).astype("int32")),
+                exec_mode=exec_mode,
+            )
             test_out.seek(0)
             printed_text = str(test_out.read())
             expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 
1\nAnother print: 1 (1, 1)\n"
@@ -95,65 +102,65 @@ def test_print():
         sys.stdout = stdout
 
 
-def test_assert_passes():
+def test_assert_passes(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(True))
         return x
 
-    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_passes_with_format_args():
+def test_assert_passes_with_format_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(True), x, format="You won't see me")
         return x
 
-    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_fails():
+def test_assert_fails(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False))
         return x
 
     with pytest.raises(AssertionError, match="Assertion Failed"):
-        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_fails_with_message():
+def test_assert_fails_with_message(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), format="I failed...")
         return x
 
     with pytest.raises(AssertionError, match="I failed..."):
-        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_fails_with_args():
+def test_assert_fails_with_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), [x, x])
         return x
 
     with pytest.raises(AssertionError, match="5, 5"):
-        run_cpu(func, tvm.nd.array(np.array(5).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(5).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_fails_with_formatted_args():
+def test_assert_fails_with_formatted_args(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor((), "int32")):
         _ = R.assert_op(relax.const(False), x, format="Number: {}")
         return x
 
     with pytest.raises(AssertionError, match="Number: 6"):
-        run_cpu(func, tvm.nd.array(np.array(6).astype("int32")))
+        run_cpu(func, tvm.nd.array(np.array(6).astype("int32")), 
exec_mode=exec_mode)
 
 
-def test_assert_on_argument_passes():
+def test_assert_on_argument_passes(exec_mode):
     @R.function(pure=False)
     def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
         _ = R.assert_op(condition)
@@ -161,10 +168,10 @@ def test_assert_on_argument_passes():
 
     condition = tvm.nd.array(np.array(True))
     x = tvm.nd.array(np.array(5).astype("int32"))
-    run_cpu(func, condition, x)
+    run_cpu(func, condition, x, exec_mode=exec_mode)
 
 
-def test_assert_on_argument_fails():
+def test_assert_on_argument_fails(exec_mode):
     @R.function(pure=False)
     def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
         _ = R.assert_op(condition)
@@ -173,10 +180,10 @@ def test_assert_on_argument_fails():
     condition = tvm.nd.array(np.array(False))
     x = tvm.nd.array(np.array(5).astype("int32"))
     with pytest.raises(AssertionError):
-        run_cpu(func, condition, x)
+        run_cpu(func, condition, x, exec_mode=exec_mode)
 
 
-def test_assert_on_symbolic_var_passes():
+def test_assert_on_symbolic_var_passes(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor(["N"], "int32")):
         N = T.int64()
@@ -184,10 +191,10 @@ def test_assert_on_symbolic_var_passes():
         return x
 
     x = tvm.nd.array(np.arange(8, dtype="int32"))
-    run_cpu(func, x)
+    run_cpu(func, x, exec_mode=exec_mode)
 
 
-def test_assert_on_symbolic_var_fails():
+def test_assert_on_symbolic_var_fails(exec_mode):
     @R.function(pure=False)
     def func(x: R.Tensor(["N"], "int32")):
         N = T.int64()
@@ -196,7 +203,7 @@ def test_assert_on_symbolic_var_fails():
 
     x = tvm.nd.array(np.arange(10, dtype="int32"))
     with pytest.raises(AssertionError):
-        run_cpu(func, x)
+        run_cpu(func, x, exec_mode=exec_mode)
 
 
 @tvm.script.ir_module
@@ -223,23 +230,31 @@ class ShapeOfTest:
         return R.shape_of(x)
 
 
-def test_op_shape_of():
-    unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape")
+def test_op_shape_of(exec_mode):
+    unit_shape = run_cpu(ShapeOfTest, "get_scalar_shape", exec_mode=exec_mode)
     assert unit_shape == tvm.runtime.ShapeTuple([])
 
-    const_shape = run_cpu(ShapeOfTest, "get_constant_shape")
+    const_shape = run_cpu(ShapeOfTest, "get_constant_shape", 
exec_mode=exec_mode)
     assert const_shape == tvm.runtime.ShapeTuple([2, 2])
 
-    scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, 
dtype="int32")))
+    scalar_shape = run_cpu(
+        ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")), 
exec_mode=exec_mode
+    )
     assert scalar_shape == tvm.runtime.ShapeTuple([])
 
     tensor_shape = run_cpu(
-        ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 
3)).astype("int32"))
+        ShapeOfTest,
+        "get_shape",
+        tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")),
+        exec_mode=exec_mode,
     )
     assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])
 
     constrained_shape = run_cpu(
-        ShapeOfTest, "get_constrained_shape", 
tvm.nd.array(np.zeros((1,)).astype("int32"))
+        ShapeOfTest,
+        "get_constrained_shape",
+        tvm.nd.array(np.zeros((1,)).astype("int32")),
+        exec_mode=exec_mode,
     )
     assert constrained_shape == tvm.runtime.ShapeTuple([1])
 
@@ -257,7 +272,7 @@ class ShapeToTensorTest:
         return R.shape_to_tensor(shape)
 
 
-def test_op_shape_to_tensor():
+def test_op_shape_to_tensor(exec_mode):
     # Check struct info
     isinstance(ShapeToTensorTest["const_shape"].body.struct_info, 
tvm.relax.TensorStructInfo)
     assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1
@@ -265,24 +280,32 @@ def test_op_shape_to_tensor():
     assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1
 
     # Check its functionality
-    out2d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 2]))
+    out2d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 2]), 
exec_mode=exec_mode
+    )
     assert isinstance(out2d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out2d.numpy(), np.array([3, 2]))
 
-    out3d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 3, 2]))
+    out3d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2]), 
exec_mode=exec_mode
+    )
     assert isinstance(out3d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))
 
-    out4d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 3, 2, 2]))
+    out4d = run_cpu(
+        ShapeToTensorTest, "const_shape", tvm.runtime.ShapeTuple([3, 3, 2, 
2]), exec_mode=exec_mode
+    )
     assert isinstance(out4d, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))
 
-    outs = run_cpu(ShapeToTensorTest, "symbolic_shape", 
tvm.runtime.ShapeTuple([3, 2]))
+    outs = run_cpu(
+        ShapeToTensorTest, "symbolic_shape", tvm.runtime.ShapeTuple([3, 2]), 
exec_mode=exec_mode
+    )
     assert isinstance(outs, tvm.runtime.ndarray.NDArray)
     assert np.array_equal(outs.numpy(), np.array([3, 2]))
 
 
-def test_op_call_pure_packed():
+def test_op_call_pure_packed(exec_mode):
     @tvm.script.ir_module
     class CallPureTest:
         @R.function
@@ -294,11 +317,11 @@ def test_op_call_pure_packed():
 
     np.random.seed(0)  # to avoid flakiness
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr))
+    copy_found = run_cpu(CallPureTest, "pure_copy", tvm.nd.array(arr), 
exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()
 
 
-def test_op_call_inplace_packed():
+def test_op_call_inplace_packed(exec_mode):
     # in this case we can use the same test as above
     @tvm.script.ir_module
     class CallInplaceTest:
@@ -312,7 +335,7 @@ def test_op_call_inplace_packed():
             )
             return z
 
-    @tvm.register_func("test.inplace.add")
+    @tvm.register_func("test.inplace.add", override=True)
     def inplace_add(a, b):
         arr_a = a.numpy()
         arr_b = b.numpy()
@@ -340,11 +363,13 @@ def test_op_call_inplace_packed():
     arr_b = np.random.rand(3, 4).astype("float32")
     sum = arr_a + arr_b
     tvm_arr_a = tvm.nd.array(arr_a)
-    result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, 
tvm.nd.array(arr_b))
+    result = run_cpu(
+        CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b), 
exec_mode=exec_mode
+    )
     assert result == tvm_arr_a
     assert (result.numpy() == sum).all()
 
-    @tvm.register_func("test.inplace.tuple_add")
+    @tvm.register_func("test.inplace.tuple_add", override=True)
     def inplace_tuple_add(a, b):
         arr_a = a.numpy()
         arr_b = b.numpy()
@@ -374,14 +399,14 @@ def test_op_call_inplace_packed():
     sum = arr_a + arr_b
     tvm_arr_a = tvm.nd.array(arr_a)
     tvm_arr_b = tvm.nd.array(arr_b)
-    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b, 
exec_mode=exec_mode)
     assert result[0] == tvm_arr_a
     assert (result[0].numpy() == sum).all()
     assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
     assert (result[1].numpy() == sum).all()
 
 
-def test_op_to_device():
+def test_op_to_device(exec_mode):
     @tvm.script.ir_module
     class CallToDevice:
         @R.function
@@ -397,11 +422,11 @@ def test_op_to_device():
 
     np.random.seed(0)  # to avoid flakiness
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr))
+    copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr), 
exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()
 
 
-def test_op_to_vdevice():
+def test_op_to_vdevice(exec_mode):
     @tvm.script.ir_module
     class ToVDevice:
         I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
@@ -414,11 +439,11 @@ def test_op_to_vdevice():
 
     np.random.seed(0)
     arr = np.random.rand(3, 4).astype("float32")
-    copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr))
+    copy_found = run_cpu(ToVDevice, "to_vdev", tvm.nd.array(arr), 
exec_mode=exec_mode)
     assert (copy_found.numpy() == arr).all()
 
 
-def test_scalar_tensor_as_branch_condition():
+def test_scalar_tensor_as_branch_condition(exec_mode):
     """The condition of a branch may be a scalar tensor"""
 
     @R.function
@@ -429,14 +454,14 @@ def test_scalar_tensor_as_branch_condition():
             out = R.prim_value(10)
         return out
 
-    res = run_cpu(func, tvm.nd.array(np.array(True)))
+    res = run_cpu(func, tvm.nd.array(np.array(True)), exec_mode=exec_mode)
     assert res == 5
 
-    res = run_cpu(func, tvm.nd.array(np.array(False)))
+    res = run_cpu(func, tvm.nd.array(np.array(False)), exec_mode=exec_mode)
     assert res == 10
 
 
-def test_prim_value_as_branch_condition():
+def test_prim_value_as_branch_condition(exec_mode):
     """The condition may be a PrimValue"""
 
     @R.function
@@ -447,14 +472,14 @@ def test_prim_value_as_branch_condition():
             out = R.prim_value(10)
         return out
 
-    res = run_cpu(func, True)
+    res = run_cpu(func, True, exec_mode=exec_mode)
     assert res == 5
 
-    res = run_cpu(func, False)
+    res = run_cpu(func, False, exec_mode=exec_mode)
     assert res == 10
 
 
-def test_computed_prim_value_as_branch_condition():
+def test_computed_prim_value_as_branch_condition(exec_mode):
     """The R.Prim condition may be computed within the function"""
 
     @R.function
@@ -466,10 +491,10 @@ def test_computed_prim_value_as_branch_condition():
             out = R.prim_value(10)
         return out
 
-    res = run_cpu(func, tvm.nd.array(np.arange(16)))
+    res = run_cpu(func, tvm.nd.array(np.arange(16)), exec_mode=exec_mode)
     assert res == 5
 
-    res = run_cpu(func, tvm.nd.array(np.arange(20)))
+    res = run_cpu(func, tvm.nd.array(np.arange(20)), exec_mode=exec_mode)
     assert res == 10
 
 

Reply via email to