This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 36e3c121b7 [Relax] Validate StructInfo annotations in well-formed check (#17331)
36e3c121b7 is described below

commit 36e3c121b7dcfae3d5d5098186a7ca96e7ff27fc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 19 12:28:25 2024 -0500

    [Relax] Validate StructInfo annotations in well-formed check (#17331)
    
    * [Relax] Validate StructInfo annotations in well-formed check
    
    Prior to this commit, the Relax well-formed checker verified that each
    expression had a non-null `StructInfo` annotation, but did not perform
    any validation on the contents of the `StructInfo` annotation.
    
    This commit updates the Relax well-formed check to verify that the
    `StructInfo` annotations are accurate by comparing them against the
    `StructInfo` that would be inferred for each expression.  (This only
    requires that the information is accurate, not that it is complete.
    For example, an expression that is inferred to be
    `R.Tensor(shape=[128,8], dtype="float32")` may have an annotation of
    `R.Tensor(ndim=2, dtype="float32")`, but may not have an annotation of
    `R.Tensor(shape=[128,8], dtype="int32")`.)
    
    * lint fix
    
    * lint fix
---
 src/relax/analysis/well_formed.cc                  | 43 +++++++++++
 src/relax/op/op.cc                                 | 21 ++++--
 tests/python/relax/test_analysis_well_formed.py    | 85 ++++++++++++++++++++++
 tests/python/relax/test_ast_printer.py             |  4 +-
 tests/python/relax/test_frontend_from_fx.py        | 10 +--
 tests/python/relax/test_transform_decompose_ops.py |  4 +-
 .../relax/test_transform_ipc_allreduce_rewrite.py  |  4 +-
 .../relax/test_transform_legalize_ops_ccl.py       |  4 +-
 .../test_transform_legalize_ops_create_datatype.py | 34 ++++-----
 ..._transform_legalize_ops_index_linear_algebra.py |  2 +-
 .../test_transform_legalize_ops_manipulate.py      | 51 +++++++------
 .../python/relax/test_transform_legalize_ops_nn.py | 38 ++++++----
 ...st_transform_legalize_ops_search_statistical.py |  4 +-
 .../python/relax/test_transform_realize_vdevice.py | 16 ++--
 .../test_transform_static_plan_block_memory.py     |  8 +-
 .../relax/test_transform_to_mixed_precision.py     | 12 +--
 tests/python/relax/test_tvmscript_parser.py        | 10 +--
 tests/python/relax/test_vm_cuda_graph.py           |  8 +-
 tests/python/relax/test_vm_multi_device.py         | 14 ++--
 19 files changed, 268 insertions(+), 104 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
index 7688c4a642..7873d5ce20 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
                                           << err.what());
       }
     }
+
+    if (check_struct_info_ && call->struct_info_.defined()) {
+      // The `InferStructInfo` method isn't currently exposed by the
+      // Normalizer, and can only be called indirectly by normalizing
+      // an expression that does not yet have `StructInfo`.
+      auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
+      Call copied(call->op, call->args, call->attrs, call->sinfo_args);
+      Optional<Expr> normalized = NullOpt;
+      try {
+        normalized = dummy_builder->Normalize(copied);
+      } catch (std::exception& err) {
+        Malformed(Diagnostic::Error(call)
+                  << "Each Relax expression must be able to have its 
StructInfo inferred.  "
+                  << "However, inferring the struct info of expression " << 
GetRef<Call>(call)
+                  << " resulted in the error: \n"
+                  << err.what());
+      }
+      if (normalized.defined()) {
+        auto inferred_struct_info = GetStructInfo(normalized.value());
+        auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
+
+        // An error should be raised if the annotated StructInfo is
+        // provably incorrect.  This check is done using
+        // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1`
+        // represents cases that are neither provably correct nor
+        // provably incorrect.  If this check were replaced with
+        // `!IsBaseOf(...)`, cases that are correct but not provably
+        // so would raise an exception.
+        //
+        // For example, if a dynamic size in the inferred StructInfo
+        // is equivalent to the expression used in the annotated
+        // StructInfo, but the TIR simplifications are not sufficient
+        // to prove that the two expressions are equivalent, we should
+        // not raise an error.
+        if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
+            BaseCheckResult::kFailL1) {
+          Malformed(Diagnostic::Error(call)
+                    << "All information in StructInfo annotations must be 
correct.  "
+                    << "However, while the expression " << GetRef<Call>(call) 
<< " is annotated as "
+                    << current_struct_info << ", the expression outputs " << 
inferred_struct_info);
+        }
+      }
+    }
   }
 
   void VisitExpr_(const IfNode* op) final {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 3e0f0eba31..a7d97a59a1 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& 
call, const BlockBuilder& c
   ICHECK(call->args.size() == 1);
   ICHECK(call->args[0]->struct_info_.defined());
   const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
-  ICHECK(tsinfo && tsinfo->shape.defined());
-  ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
-  ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected 
argument to be 1-d, "
-                                         << "but " << call << " has argument " 
<< call->args[0]
-                                         << " with struct info " << 
call->args[0]->struct_info_;
-  const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
-  ICHECK(ndim);
-  return ShapeStructInfo(ndim->value);
+  ICHECK(tsinfo);
+  ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 
1-d, "
+                             << "but " << call << " has argument " << 
call->args[0]
+                             << " with struct info " << 
call->args[0]->struct_info_;
+
+  if (tsinfo->shape.defined()) {
+    ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+    const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+    if (ndim) {
+      return ShapeStructInfo(ndim->value);
+    }
+  }
+  return ShapeStructInfo(kUnknownNDim);
 }
 
 RELAY_REGISTER_OP("relax.tensor_to_shape")
diff --git a/tests/python/relax/test_analysis_well_formed.py 
b/tests/python/relax/test_analysis_well_formed.py
index 3db3efee1a..d9eefcfd0e 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -1295,5 +1295,90 @@ def 
test_var_binding_with_incomplete_struct_info_must_be_consistent():
     assert not rx.analysis.well_formed(main)
 
 
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_annotations_must_be_correct():
+    """StructInfo annotations must be correct
+
+    To be well-formed, the inferred struct info must not conflict with
+    the StructInfo annotations.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_may_be_incomplete():
+    """StructInfo annotations may be less specific
+
+    The StructInfo annotations are not required to be an exact match
+    to the inferred StructInfo, and may provide less specific
+    information than the inference would provide.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Object = R.add(A, B)
+            return C
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_ast_printer.py 
b/tests/python/relax/test_ast_printer.py
index 6005ecb0fa..1df7dcf36f 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -366,8 +366,8 @@ def test_call_packed():
     ) -> R.Object:
         m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
-        w: R.Tensor = R.multiply(z, z)
-        q: R.Tensor(ndim=2) = R.add(w, w)
+        w: R.Tensor(ndim=2) = R.multiply(z, z)
+        q: R.Tensor = R.add(w, w)
         t = R.add(w, z)
         sh: R.Shape = R.shape_of(t)
         o: R.Object = R.call_packed(
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 78fc7abdf7..191ea4da5e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -79,7 +79,7 @@ def test_conv1d():
                     out_layout="NCW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 
6, 1])
                 lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
                 R.output(gv)
@@ -171,7 +171,7 @@ def test_conv1d_transpose():
                     out_layout="NCW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 
6, 1])
                 lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 6), dtype="float32") = lv3
                 R.output(gv)
@@ -263,7 +263,7 @@ def test_conv2d():
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, 
[1, 6, 1, 1])
                 lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
                 R.output(gv)
@@ -355,7 +355,7 @@ def test_conv2d_transpose():
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+                lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, 
[1, 3, 1, 1])
                 lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, 
lv2)
                 gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3
                 R.output(gv)
@@ -447,7 +447,7 @@ def test_conv3d():
                     out_layout="NCDHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+                lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = 
R.reshape(w2, [1, 6, 1, 1, 1])
                 lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, 
lv2)
                 gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
                 R.output(gv)
diff --git a/tests/python/relax/test_transform_decompose_ops.py 
b/tests/python/relax/test_transform_decompose_ops.py
index 4e5bcb82e9..2564913d79 100644
--- a/tests/python/relax/test_transform_decompose_ops.py
+++ b/tests/python/relax/test_transform_decompose_ops.py
@@ -360,14 +360,14 @@ def test_op_tensor_to_shape():
     @I.ir_module
     class Before:
         @R.function
-        def main(t: R.Tensor(ndim=1, dtype="int64")):
+        def main(t: R.Tensor([3], dtype="int64")):
             gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
             return gv
 
     @I.ir_module
     class Expected:
         @R.function
-        def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
+        def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3):
             x = T.int64()
             x_1 = T.int64()
             x_2 = T.int64()
diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py 
b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
index da85423aaf..fa68c16e69 100644
--- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
+++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
@@ -83,7 +83,7 @@ def test_ipc_allreduce_spread_along_reshape():
             alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( 
 # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), 
R.str("global")
             )
-            lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m * 
n,))  # type: ignore
+            lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m * 
n,))  # type: ignore
             alloc1: R.Tensor((m * n,), dtype="float16") = 
R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m * n]), R.dtype("float16"), R.prim_value(0), 
R.str("global")
             )
@@ -103,7 +103,7 @@ def test_ipc_allreduce_spread_along_reshape():
             alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( 
 # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), 
R.str("ipc_memory")
             )
-            lv1: R.Tensor((m, n), dtype="float16") = R.reshape(  # type: ignore
+            lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(  # type: 
ignore
                 alloc, R.shape([m * n])
             )
             alloc1: R.Tensor((m * n,), dtype="float16") = 
R.builtin.alloc_tensor(  # type: ignore
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py 
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 9ea4d21d61..923a8e8d97 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -101,8 +101,8 @@ def test_scatter_from_worker0():
     @tvm.script.ir_module
     class ScatterFromWorker0:
         @R.function
-        def main(x: R.Tensor((10, 10), "float32"))  -> R.Tensor((5, 10), 
"float32"):
-            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, 
num_workers=2, axis=1)
+        def main(x: R.Tensor((10, 10), "float32"))  -> R.Tensor((10,5), 
"float32"):
+            gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x, 
num_workers=2, axis=1)
             return gv0
 
     @I.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py 
b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
index 7b2b2d2e76..a8af295ac3 100644
--- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py
+++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
@@ -160,19 +160,19 @@ def test_full_like():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> 
R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.full_like(x, v)
+        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> 
R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.full_like(x, v)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> 
R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), 
dtype="float32"))
+        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> 
R.Tensor((2, 3), "int32"):
+            gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), 
dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
-        def full(rxplaceholder: T.Buffer((), "float32"), T_full: 
T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def full(rxplaceholder: T.Buffer((), "float32"), T_full: 
T.Buffer((T.int64(2), T.int64(3)), "int32")):
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_full"):
@@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, 
"float32"))
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5, 
"float32"))
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), 
dtype="float32"))
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), 
dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
-        def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_full"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
                     T.reads()
                     T.writes(T_full[ax0, ax1])
-                    T_full[ax0, ax1] = T.float32(-5)
+                    T_full[ax0, ax1] = T.int32(-5)
     # fmt: on
 
     mod = LegalizeOps()(FullLike)
@@ -253,19 +253,19 @@ def test_full_like_symbolic():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "float32"):
+        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "int32"):
             m = T.int64()
             n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
+            gv: R.Tensor((m, n), "int32") = R.full_like(x, v)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "float32"):
+        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) 
-> R.Tensor(("m", "n"), "int32"):
             m = T.int64()
             n = T.int64()
-            gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), 
dtype="float32"))
+            gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), 
dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
@@ -273,13 +273,13 @@ def test_full_like_symbolic():
             T.func_attr({"tir.noalias": True})
             m = T.int64()
             n = T.int64()
-            T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
+            T_full = T.match_buffer(var_T_full, [m, n], dtype="int32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
                     T.reads(rxplaceholder[()])
                     T.writes(T_full[ax0, ax1])
-                    T_full[ax0, ax1] = rxplaceholder[()]
+                    T_full[ax0, ax1] = T.int32(rxplaceholder[()])
     # fmt: on
 
     mod = LegalizeOps()(FullLike)
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index d0aaddb1ca..2f4da5cf06 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -230,7 +230,7 @@ def test_strided_slice_no_strides():
     class StridedSlice:
         @R.function
         def main(x: R.Tensor((8, 9, 10, 10), "float32")) :
-            gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, 
axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
+            gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x, 
axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
             return gv
 
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ba5d4d7d12..a0ecd3c73d 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -691,9 +691,12 @@ def test_data_dependent_reshape():
     @tvm.script.ir_module
     class DDReshape:
         @R.function
-        def main(x: R.Tensor((3, ), dtype="int64")):
-            lv: R.Shape([3,]) = R.tensor_to_shape(x)
-            gv = R.reshape(x, lv)
+        def main(
+            x: R.Tensor([2], dtype="int64"),
+            y: R.Tensor([16],dtype='float32'),
+        ):
+            lv: R.Shape(ndim=2) = R.tensor_to_shape(x)
+            gv = R.reshape(y, lv)
             return gv
     # fmt: on
 
@@ -704,29 +707,35 @@ def test_data_dependent_reshape():
     # fmt: off
     @I.ir_module
     class Expected:
+        @R.function
+        def main(
+                x: R.Tensor([2], dtype="int64"),
+                y: R.Tensor([16],dtype="float32"),
+        ) -> R.Tensor(ndim=2, dtype="float32"):
+            M = T.int64()
+            N = T.int64()
+            gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x, 
sinfo_args=(R.Shape(ndim=2),))
+            _ = R.match_cast(gv, R.Shape([M,N]))
+            _ = R.shape([M,N])
+            gv_1 = R.call_tir(Expected.reshape, (y,), 
out_sinfo=R.Tensor([M,N], dtype="float32"))
+            return gv_1
+
         @T.prim_func(private=True)
         def reshape(
-            rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: 
T.handle
+            rxplaceholder: T.Buffer(T.int64(16), "float32"),
+            var_T_reshape: T.handle,
         ):
             T.func_attr({"tir.noalias": True})
-            x = T.int64()
-            T_reshape = T.match_buffer(var_T_reshape, (x,), "int64")
-            # with T.block("root"):
-            for ax0 in range(x):
+            M = T.int64()
+            N = T.int64()
+            T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32")
+            for i,j in T.grid(M,N):
                 with T.block("T_reshape"):
-                    v_ax0 = T.axis.spatial(x, ax0)
-                    T.reads(rxplaceholder[v_ax0 % T.int64(3)])
-                    T.writes(T_reshape[v_ax0])
-                    T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)]
+                    vi,vj = T.axis.remap('SS',[i,j])
+                    T.reads(rxplaceholder[(vi*N + vj) % 16])
+                    T.writes(T_reshape[vi,vj])
+                    T_reshape[vi,vj] = rxplaceholder[(vi*N + vj) % 16]
 
-        @R.function
-        def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, 
dtype="int64"):
-            x_1 = T.int64()
-            gv: R.Shape([3]) = 
R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),))
-            y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1]))
-            lv: R.Shape([x_1]) = R.shape([x_1])
-            gv_1 = R.call_tir(Expected.reshape, (x,), 
out_sinfo=R.Tensor((x_1,), dtype="int64"))
-            return gv_1
     # fmt: on
     tvm.ir.assert_structural_equal(out_mod, Expected)
 
@@ -914,7 +923,7 @@ def test_squeeze_no_axis():
     class Squeeze:
         @R.function
         def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
-            gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x)
+            gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x)
             return gv
 
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 92d139d23b..d03d48968d 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -33,7 +33,7 @@ def test_conv1d():
     class Conv1d:
         @R.function
         def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 
3), "float32")) -> R.Tensor((2, 64, 13), "float32"):
-            gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w, 
strides=(2,), padding=(1,), dilation=(2,), groups=8)
+            gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w, 
strides=(2,), padding=(1,), dilation=(2,), groups=8)
             return gv
 
     @tvm.script.ir_module
@@ -210,7 +210,7 @@ def test_conv2d():
     class Conv2d:
         @R.function
         def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 
16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"):
-            gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w, 
strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8)
+            gv: R.Tensor((2, 64, 13, 13), "float32") = R.nn.conv2d(x, w, 
strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8)
             return gv
 
     @tvm.script.ir_module
@@ -3298,20 +3298,32 @@ def test_nll_loss():
     @tvm.script.ir_module
     class NLLLoss:
         @R.function
-        def main(predictions: R.Tensor((2, 3, 4, 5), "float32"), targets: 
R.Tensor((2, 4, 5), "int64"), weights: R.Tensor((4,), "float32")) -> 
R.Tensor((), "float32"):
-            gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets, 
weights, reduction="mean", ignore_index=-1)
+        def main(
+                predictions: R.Tensor((2, 3, 4, 5), "float32"),
+                targets: R.Tensor((2, 4, 5), "int64"),
+                weights: R.Tensor((3,), "float32"),
+        ) -> R.Tensor((), "float32"):
+            gv = R.nn.nll_loss(predictions, targets, weights, 
reduction="mean", ignore_index=-1)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), 
targets: R.Tensor((2, 4, 5), dtype="int64"), weights: R.Tensor((4,), 
dtype="float32"),) -> R.Tensor((), dtype="float32"):
-            # block 0
+        def main(
+                predictions: R.Tensor((2, 3, 4, 5), dtype="float32"),
+                targets: R.Tensor((2, 4, 5), dtype="int64"),
+                weights: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
             gv = R.call_tir(Expected.nll_loss, (predictions, targets, 
weights), R.Tensor((), dtype="float32"))
             return gv
 
         @T.prim_func(private=True)
-        def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), 
T.int64(4), T.int64(5)), "int64"), rxplaceholder_2: T.Buffer(T.int64(4), 
"float32"), T_divide: T.Buffer((), "float32"),):
+        def nll_loss(
+                predictions: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float32"),
+                targets: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), 
"int64"),
+                weights: T.Buffer(T.int64(3), "float32"),
+                output: T.Buffer((), "float32"),
+        ):
             # function attr dict
             T.func_attr({"tir.noalias": True})
             # body
@@ -3323,9 +3335,9 @@ def test_nll_loss():
             for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                 with T.block("nll_loss"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], 
rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], 
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
+                    T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0, 
targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1, 
v_ax2]])
                     T.writes(nll_loss[v_ax0, v_ax1, v_ax2])
-                    nll_loss[v_ax0, v_ax1, v_ax2] = 
T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - 
rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) * 
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
+                    nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, 
v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0, 
v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]], 
T.float32(0))
             for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                 with T.block("nll_loss_red"):
                     v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
@@ -3337,9 +3349,9 @@ def test_nll_loss():
             for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                 with T.block("nll_loss_1"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2], 
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
+                    T.reads(targets[v_ax0, v_ax1, v_ax2], 
weights[targets[v_ax0, v_ax1, v_ax2]])
                     T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2])
-                    nll_loss_1[v_ax0, v_ax1, v_ax2] = 
T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), 
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
+                    nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0, 
v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]], 
T.float32(0))
             for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
                 with T.block("nll_loss_red_1"):
                     v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
@@ -3351,8 +3363,8 @@ def test_nll_loss():
             with T.block("T_divide"):
                 vi = T.axis.spatial(1, T.int64(0))
                 T.reads(nll_loss_red[()], nll_loss_red_1[()])
-                T.writes(T_divide[()])
-                T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()]
+                T.writes(output[()])
+                output[()] = nll_loss_red[()] / nll_loss_red_1[()]
     # fmt: on
     mod = LegalizeOps()(NLLLoss)
     tvm.ir.assert_structural_equal(mod, Expected)
diff --git 
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py 
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 2a28151dbe..f8dab89815 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -999,8 +999,8 @@ def test_variance_no_keepdims():
     @tvm.script.ir_module
     class Variance:
         @R.function
-        def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4, 
1), "float32"):
-            gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3], 
keepdims=False)
+        def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), 
"float32"):
+            gv: R.Tensor((3, 4), "float32") = R.variance(x, [0, 3], 
keepdims=False)
             return gv
 
     @I.ir_module
diff --git a/tests/python/relax/test_transform_realize_vdevice.py 
b/tests/python/relax/test_transform_realize_vdevice.py
index 4c530d5e49..fa64282184 100644
--- a/tests/python/relax/test_transform_realize_vdevice.py
+++ b/tests/python/relax/test_transform_realize_vdevice.py
@@ -61,8 +61,9 @@ def test_dataflow_binding():
                 y1 = y
                 x2 = x1
                 y2 = y1
-                lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
-                gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z)
+                x2 = R.hint_on_device(x2, tvm.cpu())
+                lv0 = R.add(x2, y2)
+                gv = R.multiply(lv0, z)
                 R.output(gv)
             return gv
 
@@ -91,6 +92,7 @@ def test_dataflow_binding():
                 y1: R.Tensor((2, 3), "float32", "llvm") = y
                 x2: R.Tensor((2, 3), "float32", "llvm") = x1
                 y2: R.Tensor((2, 3), "float32", "llvm") = y1
+                x2: R.Tensor((2, 3), "float32", "llvm") = x2
                 lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
                 gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z)
                 R.output(gv)
@@ -121,7 +123,8 @@ def test_binding():
             y1 = y
             x2 = x1
             y2 = y1
-            s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
+            x2 = R.hint_on_device(x2, tvm.cpu())
+            s = R.add(x2, y2)
             m = R.multiply(s, z)
             return m
 
@@ -146,6 +149,7 @@ def test_binding():
             y1: R.Tensor((2, 3), "float32", "llvm") = y
             x2: R.Tensor((2, 3), "float32", "llvm") = x1
             y2: R.Tensor((2, 3), "float32", "llvm") = y1
+            x2: R.Tensor((2, 3), "float32", "llvm") = x2
             s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
             m: R.Tensor((2, 3), "float32", "llvm") = R.multiply(s, z)
             return m
@@ -275,10 +279,11 @@ def test_multi_device():
             z: R.Tensor((2, 3), "float32"),
         ) -> R.Tensor((2, 3), "float32", "cuda"):
             with R.dataflow():
-                lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y)
+                lv0 = R.add(x, y)
+                lv0 = R.hint_on_device(lv0, tvm.cpu())
                 lv1 = R.to_vdevice(lv0, "cuda")
                 lv2 = R.add(z, z)
-                gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2)
+                gv = R.multiply(lv1, lv2)
                 R.output(gv)
             return gv
 
@@ -304,6 +309,7 @@ def test_multi_device():
         ) -> R.Tensor((2, 3), "float32", "cuda"):
             with R.dataflow():
                 lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y)
+                lv0: R.Tensor((2, 3), "float32", "llvm") = lv0
                 lv1: R.Tensor((2, 3), "float32", "cuda") = R.to_vdevice(lv0, 
"cuda")
                 lv2: R.Tensor((2, 3), "float32", "cuda") = R.add(z, z)
                 gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index f9e632d348..1150827b19 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1386,11 +1386,11 @@ def test_add():
             )
             cls.cumsum(probs, lv1, alloc1)
             cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = 
alloc1
-            lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = 
R.call_packed(
+            lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = 
R.call_packed(
                 "vm.builtin.reshape",
                 cumsum,
                 R.shape([batch_size, vocab_size]),
-                sinfo_args=(R.Tensor((batch_size, vocab_size), 
dtype="float"),),
+                sinfo_args=(R.Tensor((batch_size, vocab_size), 
dtype="float32"),),
             )
             return lv1_1
 
@@ -1403,7 +1403,7 @@ def test_add():
         @R.function
         def main(
             probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
-        ) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"):
+        ) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
             batch_size = T.int64()
             vocab_size = T.int64()
             R.func_attr(
@@ -1437,7 +1437,7 @@ def test_add():
             )
             cls.cumsum(probs, lv1, alloc1)
             cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = 
alloc1
-            lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = 
R.call_packed(
+            lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") = 
R.call_packed(
                 "vm.builtin.reshape",
                 cumsum,
                 R.shape([batch_size, vocab_size]),
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py 
b/tests/python/relax/test_transform_to_mixed_precision.py
index ed10fc95c7..658f80a06e 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -906,15 +906,15 @@ def test_conv2d_bias_fp32():
         ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
             # block 0
             with R.dataflow():
-                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = 
R.nn.conv2d(
                     x,
                     w,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
                     out_dtype="float32",
                 )
-                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = 
R.reshape(bias, (1, 512, 1, 1))
-                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = 
R.add(lv142, lv143)
+                lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = 
R.reshape(bias, (1, 512, 1, 1))
+                lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = 
R.add(lv142, lv143)
                 R.output(lv144)
             return lv144
 
@@ -1001,15 +1001,15 @@ def test_convert_sig():
         ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
             # block 0
             with R.dataflow():
-                lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+                lv142: R.Tensor((1, 512, 62, 62), dtype="float32") = 
R.nn.conv2d(
                     x,
                     w,
                     strides=[1, 1],
                     padding=[0, 0, 0, 0],
                     out_dtype="float32",
                 )
-                lv143: R.Tensor((1, 4, 1, 1), dtype="float32") = 
R.reshape(bias, (1, 512, 1, 1))
-                lv144: R.Tensor((1, 4, 64, 64), dtype="float32") = 
R.add(lv142, lv143)
+                lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = 
R.reshape(bias, (1, 512, 1, 1))
+                lv144: R.Tensor((1, 512, 62, 62), dtype="float32") = 
R.add(lv142, lv143)
                 R.output(lv144)
             return lv144
 
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index fa62d14848..3e64c928ae 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -882,8 +882,8 @@ def test_annotation():
     ) -> R.Object:
         m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
-        w: R.Tensor = R.multiply(z, z)
-        q: R.Tensor(ndim=2) = R.add(w, w)
+        w: R.Tensor(ndim=2) = R.multiply(z, z)
+        q: R.Tensor = R.add(w, w)
         t = R.add(w, z)
         sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape)
         lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh)
@@ -902,9 +902,9 @@ def test_annotation():
     sh = bindings[4].var
 
     _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32"))
-    _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1))
-    _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2))
-    _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1))
+    _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2))
+    _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1))
+    _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2))
     _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1))
     _check_struct_info(bindings[5], relax.TensorStructInfo(sh))
     _check_struct_info(bindings[6], relax.ObjectStructInfo())
diff --git a/tests/python/relax/test_vm_cuda_graph.py 
b/tests/python/relax/test_vm_cuda_graph.py
index 49ebcc1d05..b6c8cdfdee 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -36,13 +36,13 @@ class Module:
         R.func_attr({"global_symbol": "main"})
         gv: R.Tuple(R.Object, R.Object) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", 
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, 
R.Object),))
         storage: R.Object = gv[0]
-        alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, 
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), 
R.dtype("float32"))
         _: R.Tuple = cls.add(x, alloc)
         storage1: R.Object = gv[1]
         gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, 
storage1, storage)
         gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", 
(cls.cuda_graph_capture, gv1, R.prim_value(0)), 
sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
         storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("uint8"))
-        alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, 
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 
16)), R.dtype("float32"))
         lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
         _3: R.Tuple = cls.add(lv4, alloc3)
         lv5: R.Tensor(dtype="float32") = alloc3
@@ -71,12 +71,12 @@ class Module:
         cls = Module
         R.func_attr({"global_symbol": "cuda_graph_capture"})
         lv0: R.Tensor((16, 16), dtype="float32") = alloc
-        alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, 
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 
16)), R.dtype("float32"))
         _1: R.Tuple = cls.add(lv0, alloc1)
         lv1: R.Tensor(dtype="float32") = alloc1
         lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,)
         lv3: R.Tensor(dtype="float32") = lv2[0]
-        alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, 
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 
16)), R.dtype("float32"))
         _2: R.Tuple = cls.add(lv3, alloc2)
         lv4: R.Tensor(dtype="float32") = alloc2
         gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,)
diff --git a/tests/python/relax/test_vm_multi_device.py 
b/tests/python/relax/test_vm_multi_device.py
index ec2fbd1cdf..73c78d70f0 100644
--- a/tests/python/relax/test_vm_multi_device.py
+++ b/tests/python/relax/test_vm_multi_device.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test eliminate common subexpr pass"""
+
 from typing import List
 import tvm
 from tvm import relax
@@ -61,11 +62,12 @@ def test_multi_cpu():
             z: R.Tensor((4, 5), "float32"),
         ) -> R.Tensor((2, 5), "float32"):
             with R.dataflow():
-                lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y)  # 
noqa: F722
+                lv0 = R.matmul(x, y)
+                lv0 = R.hint_on_device(lv0, tvm.cpu(0))
                 lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice(  # 
noqa: F722
-                    lv0, "llvm:1"  # noqa: F722
+                    lv0, "llvm:1"
                 )
-                gv = R.matmul(lv1, z)  # noqa: F722
+                gv = R.matmul(lv1, z)
                 R.output(gv)
             return gv
 
@@ -109,11 +111,13 @@ def test_multi_gpu():
             with R.dataflow():
                 lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b)  # 
noqa: F722
                 lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice(  # 
noqa: F722
-                    lv0, "cuda:1"  # noqa: F722
+                    lv0,
+                    "cuda:1",  # noqa: F722
                 )
                 lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c)  
# noqa: F722
                 lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice(  # 
noqa: F722
-                    lv2, "cuda:2"  # noqa: F722
+                    lv2,
+                    "cuda:2",  # noqa: F722
                 )
                 gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d)  
# noqa: F722
                 R.output(gv)


Reply via email to