This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 491a0f69aa [Relax] Require correct input/output shapes `R.call_tir` 
(#17285)
491a0f69aa is described below

commit 491a0f69aabcf812cc552df7666038414ca79a8f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 08:32:31 2024 -0500

    [Relax] Require correct input/output shapes `R.call_tir` (#17285)
    
    Prior to this commit, the Relax well-formed checker validated
    arguments provided to Relax functions, but did not validate arguments
    provided to `R.call_tir`.  As a result, incorrect arguments from Relax
    to TIR would not be checked until runtime, if at all.
    
    This commit updates the well-formed checker to verify that
    `R.call_tir` has received the correct arguments, and has the correct
    output shape specified in the `out_sinfo` parameter.
    
    The initial implementation performed the validation as part of
    `FNormalize`, to maximize coverage of this check.  This increased
    end-to-end compilation time by ~10%, so it was requested that the
    check be restricted to the well-formed checker.  Expensive
    operator-specific validation is now performed in the new `FValidate`
    attribute.
---
 include/tvm/relax/op_attr_types.h                  |  27 ++
 src/relax/analysis/well_formed.cc                  |  11 +
 src/relax/op/op.cc                                 | 291 +++++++++++-
 src/relax/transform/fuse_tir.cc                    |   3 +-
 ...est_distributed_transform_propagate_sharding.py |   8 -
 tests/python/relax/test_analysis_well_formed.py    | 514 ++++++++++++++++++++-
 tests/python/relax/test_ast_printer.py             |   9 +-
 tests/python/relax/test_dataflow_inplace.py        |  10 +-
 tests/python/relax/test_dataflow_pattern.py        |   2 +-
 tests/python/relax/test_frontend_dynamo.py         |   7 +-
 tests/python/relax/test_frontend_nn_op.py          |  18 +-
 tests/python/relax/test_transform.py               |   6 +-
 .../relax/test_transform_dead_code_elimination.py  |  30 +-
 tests/python/relax/test_transform_fuse_ops.py      |   8 +-
 .../relax/test_transform_fuse_ops_by_pattern.py    |  18 +-
 .../relax/test_transform_lazy_transform_params.py  |  20 +-
 .../test_transform_rewrite_dataflow_reshape.py     |  25 +-
 tests/python/relax/test_tvmscript_parser.py        |  15 +-
 tests/python/relax/test_vm_build.py                |  12 +-
 19 files changed, 928 insertions(+), 106 deletions(-)

diff --git a/include/tvm/relax/op_attr_types.h 
b/include/tvm/relax/op_attr_types.h
index 291bee597c..0ddc2baefb 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -56,6 +56,14 @@ using FCallPacked = String;
  * expressed in multiple syntactically valid and semantically
  * equivalent forms, to normalize to a single representation.
  *
+ * Note: `FNormalize` is applied for each expression as part of the
+ *    `relax::BlockBuilder`.  While operator-specific validation may
+ *    be performed within the `FNormalize` implementation, ensuring
+ *    that errors are caught as early as possible, this should only be
+ *    used when validation is fast to apply.  If the validation logic
+ *    may be slow, it should instead be implemented in `FValidate`,
+ *    which is only run as part of the well-formed checker.
+ *
  * \param bb The BlockBuilder context.
  *
  * \param call The call to be normalized.  It is provided by-value, to
@@ -63,6 +71,25 @@ using FCallPacked = String;
  */
 using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call 
call)>;
 
+/*!
+ * \brief The function type of a validation function.
+ *
+ * A validation function is used to define constraints that should be
+ * verified for an operator as part of the well-formed checker.
+ *
+ * Note: `FValidate` is only applied as part of the well-formed
+ *    checker.  While this minimizes overhead when compiling Relax,
+ *    the delay between generating an ill-formed `relax::Call` and
+ *    identifying the ill-formed call may complicate debugging.  If
+ *    the validation logic is very fast to check, and doing so would
+ *    not introduce a significant overhead, consider validating as part
+ *    of `FNormalize`, which is applied by the block builder for each
+ *    `relax::Call`.
+ *
+ * \param call The call to be validated.
+ */
+using FValidate = runtime::TypedPackedFunc<void(const Call& call)>;
+
 /*! \brief The function type of a legalization function.
  *
  * A legalization function is used to replace a `relax::Call` with
diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
index 626fadda27..235059ece2 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor,
             << after_normalize);
       }
     }
+
+    if (auto func_validate = op_map_validate_.get(call->op, nullptr); 
func_validate != nullptr) {
+      try {
+        func_validate(GetRef<Call>(call));
+      } catch (std::exception& err) {
+        Malformed(Diagnostic::Error(call) << "Operator-specific validation 
(FValidate) for "
+                                          << call->op << " identified error: 
\n"
+                                          << err.what());
+      }
+    }
   }
 
   void VisitExpr_(const IfNode* op) final {
@@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor,
   std::unordered_map<tir::Var, const FunctionNode*> symbolic_var_func_map_;
 
   tvm::OpAttrMap<FNormalize> op_map_normalize_ = 
Op::GetAttrMap<FNormalize>("FNormalize");
+  tvm::OpAttrMap<FValidate> op_map_validate_ = 
Op::GetAttrMap<FValidate>("FValidate");
 };
 
 bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info) {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 0a840248ff..3e0f0eba31 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -18,6 +18,7 @@
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/op.h>
+#include <tvm/relax/distributed/struct_info.h>
 #include <tvm/relax/expr.h>
 #include <tvm/relax/utils.h>
 #include <tvm/relay/op.h>
@@ -242,15 +243,195 @@ 
TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla
 
 // call_tir
 
+/* If possible, infer a legal value of `out_sinfo`
+ *
+ * The `R.call_tir` operator and its variants accept an `out_sinfo`
+ * parameter, which specifies the shape of the tensor or tensors
+ * returned by a PrimFunc.  This output shape must be compatible with
+ * the shape defined by the PrimFunc's signature.
+ *
+ * For dynamic shapes, it is not always possible to infer the output
+ * of a TIR PrimFunc from its inputs.  For example, a PrimFunc that
+ * accepts input buffer `T.Buffer([16], "float32")` and output buffer
+ * `T.Buffer([M, N], "float32")` infers the values of `M` and `N` from
+ * the shape of the provided output buffer.
+ *
+ * If the arguments provided are not compatible with the PrimFunc's
+ * signature, an error will be raised.  If the arguments are
+ * compatible with the PrimFunc's signature, but are not sufficient to
+ * determine the output's StructInfo, then `NullOpt` will be returned.
+ *
+ * \param func_sinfo The StructInfo of the TIR callee.
+ * \param arg_sinfo The StructInfo of the argument tuple.
+ * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument,
+ *     if present.
+ * \param opt_inplace_indices For `R.call_tir_inplace`, an array of
+ *     indices indicating which outputs are constructed from in-place
+ *     mutation of the inputs.  See
+ *     `CallTIRInplaceAttrs::inplace_indices` for more details.
+ *
+ * \return The `out_sinfo`, if it can be inferred from the arguments.
+ *     Otherwise, NullOpt.
+ */
+static Optional<StructInfo> InferCallTIROutputStructInfoFromArguments(
+    StructInfo func_sinfo, StructInfo arg_sinfo, Optional<StructInfo> 
packed_ints_sinfo,
+    Optional<Array<Integer>> opt_inplace_indices) {
+  auto opt_callee_sinfo = func_sinfo.as<FuncStructInfo>();
+  CHECK(opt_callee_sinfo) << "TypeError: "
+                          << "The first argument to `R.call_tir` must be a 
function, "
+                          << "but instead received argument of type " << 
func_sinfo;
+  auto callee_sinfo = opt_callee_sinfo.value();
+
+  CHECK(callee_sinfo->params.defined())
+      << "ValueError: "
+      << "The first argument to `R.call_tir` must be a function "
+      << "with known argument types.  "
+      << "However, the first argument was of type " << callee_sinfo;
+  auto callee_params = callee_sinfo->params.value();
+
+  const TupleStructInfoNode* args = arg_sinfo.as<TupleStructInfoNode>();
+  CHECK(args) << "TypeError: "
+              << "The second argument to `R.call_tir` must be a tuple, "
+              << "but instead received expression of type " << arg_sinfo;
+
+  // R.call_tir expects the PrimFunc to have three groups of arguments.
+  //
+  // 1. Input arguments that are explicitly provided as Relax arguments.
+  // 2. Output tensor arguments.
+  // 3. Shape arguments, represented as `T.int64` in the PrimFunc, and
+  //    as an optional ShapeExpr argument in the `relax::Call` node.
+  //
+  // In order to determine the return type of `R.call_tir`, we must
+  // identify the PrimFunc arguments that will be in group (2).
+  size_t num_input_arguments = args->fields.size();
+  size_t num_trailing_int_arguments = 0;
+  const ShapeStructInfoNode* packed_tuple_sinfo = nullptr;
+  if (packed_ints_sinfo) {
+    auto packed_sinfo = packed_ints_sinfo.value();
+    packed_tuple_sinfo = packed_sinfo.as<ShapeStructInfoNode>();
+    CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim())
+        << "TypeError: "
+        << "The third argument to `R.call_tir`, if present, "
+        << "must be a ShapeTuple with known dimensionality.  "
+        << "However, the argument received was of type " << packed_sinfo;
+    num_trailing_int_arguments = packed_tuple_sinfo->ndim;
+  } else {
+    num_trailing_int_arguments = 0;
+  }
+
+  CHECK_LE(num_input_arguments + num_trailing_int_arguments, 
callee_params.size())
+      << "ValueError: "
+      << "R.call_tir attempted to call a function using " << 
num_input_arguments
+      << " input arguments and " << num_trailing_int_arguments << " trailing 
integer arguments.  "
+      << "However, the callee only accepts " << callee_params.size() << " 
arguments in total.";
+
+  // While Relax can specify a distributed tensor, TIR cannot.  The
+  // current implementation does not support determining the output
+  // shape for `R.dist.call_tir` calls, as it depends on the lowering
+  // of DistIR into regular Relax.
+  std::function<bool(StructInfo)> contains_dtensor = 
[&contains_dtensor](StructInfo sinfo) -> bool {
+    if (sinfo.as<distributed::DTensorStructInfoNode>()) {
+      return true;
+    } else if (auto tuple = sinfo.as<TupleStructInfoNode>()) {
+      return std::any_of(tuple->fields.begin(), tuple->fields.end(), 
contains_dtensor);
+    } else {
+      return false;
+    }
+  };
+  if (contains_dtensor(arg_sinfo)) {
+    return NullOpt;
+  }
+
+  // At this point, the return types are known.  However, the shapes
+  // in `callee_params` may contain dynamic shape parameters that are
+  // not present in the caller's scope.  The `DeriveCallRetStructInfo`
+  // utility can infer the value of dynamic parameters in
+  // `FuncStructInfoNode::ret` based on definitions in
+  // `FuncStructInfoNode::params`, inferring the correct values in the
+  // caller's scope.
+  //
+  // Since the callee of `R.call_tir` is provided with output
+  // arguments, where `DeriveCallRetStructInfo` requires a callee that
+  // produces its own outputs, a dummy function signature and
+  // arguments are used.
+
+  auto dummy_callee_sinfo = [&]() -> FuncStructInfo {
+    Array<StructInfo> dummy_params(callee_params.begin(),
+                                   callee_params.begin() + 
num_input_arguments);
+
+    for (size_t i = callee_params.size() - num_trailing_int_arguments; i < 
callee_params.size();
+         i++) {
+      dummy_params.push_back(callee_params[i]);
+    }
+
+    Array<StructInfo> dummy_ret(callee_params.begin() + num_input_arguments,
+                                callee_params.end() - 
num_trailing_int_arguments);
+
+    if (opt_inplace_indices) {
+      // For R.call_tir_inplace, the `inplace_indices` are used to
+      // indicate which elements of the `out_sinfo` will be generated
+      // as in-place mutation from an input.  For any in-place
+      // mutation, the parameter's StructInfo must be inserted into
+      // `out_sinfo`.
+      auto inplace_indices = opt_inplace_indices.value();
+      for (size_t i = 0; i < inplace_indices.size(); i++) {
+        auto inplace_input_index = inplace_indices[i]->value;
+        if (inplace_input_index >= 0) {
+          dummy_ret.insert(dummy_ret.begin() + i, 
callee_params[inplace_input_index]);
+        }
+      }
+    }
+
+    auto dummy_out_sinfo = [&]() -> StructInfo {
+      if (dummy_ret.size() == 1) {
+        return dummy_ret[0];
+      } else {
+        return TupleStructInfo(dummy_ret);
+      }
+    }();
+
+    return FuncStructInfo(dummy_params, dummy_out_sinfo);
+  }();
+
+  auto dummy_args = [&]() -> Array<Expr> {
+    Array<Expr> dummy_args = args->fields.Map(
+        [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", 
sinfo); });
+
+    for (size_t i = 0; i < num_trailing_int_arguments; i++) {
+      ICHECK(packed_tuple_sinfo);
+      PrimStructInfo dummy_arg_sinfo = [&]() {
+        if (packed_tuple_sinfo->values) {
+          return PrimStructInfo(packed_tuple_sinfo->values.value()[i]);
+        } else {
+          return PrimStructInfo(DataType::Int(64));
+        }
+      }();
+      dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo));
+    }
+
+    return dummy_args;
+  }();
+
+  auto derived_ret_sinfo = DeriveCallRetStructInfo(
+      dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), 
dummy_args),
+      BlockBuilder::Create(NullOpt));
+
+  return derived_ret_sinfo;
+}
+
 StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
   if (call->sinfo_args.size() != 1) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "sinfo_args should have exactly 1 output struct 
info.");
   }
   CHECK(call->args[0]->IsInstance<GlobalVarNode>())
-      << "call_tir expects the first argument to be a GlobalVar referring to a 
TIR PrimFunc. "
-      << "However, gets " << call->args[0];
-  return call->sinfo_args[0];
+      << "R.call_tir expects the first argument to be a GlobalVar referring to 
a TIR PrimFunc. "
+      << "However, the argument " << call->args[0] << " instead has type "
+      << call->args[0]->GetTypeKey();
+
+  StructInfo explicit_sinfo = call->sinfo_args[0];
+
+  return explicit_sinfo;
 }
 
 Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
@@ -264,23 +445,37 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) 
{
       << "or three arguments [callee, arg_tuple, tir_args], "
       << "but " << call << " has " << call->args.size() << " arguments.";
 
-  Expr arg_expr = call->args[1];
+  auto callee = call->args[0];
+  CHECK(callee->struct_info_.as<FuncStructInfoNode>())
+      << "Operation " << call->op << " expects the first argument to be a TIR 
callee.  "
+      << "However, the first argument " << callee << " has struct info " << 
callee->struct_info_;
 
-  CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
-      << "Operation " << call->op << " expects the second argument to be a 
tuple of relax Expr.  "
-      << "However, the second argument " << arg_expr << " has struct info "
-      << arg_expr->struct_info_ << ".";
+  Expr arg_tuple = call->args[1];
 
-  if (arg_expr.as<TupleNode>()) {
-    return std::move(call);
-  }
+  CHECK(arg_tuple->struct_info_.as<TupleStructInfoNode>())
+      << "Operation " << call->op << " expects the second argument to be a 
tuple of relax Expr.  "
+      << "However, the second argument " << arg_tuple << " has struct info "
+      << arg_tuple->struct_info_ << ".";
 
-  CHECK(arg_expr.as<VarNode>())
+  CHECK(arg_tuple.as<TupleNode>() || arg_tuple.as<VarNode>())
       << "Operation " << call->op << " must hold its arguments as an in-line 
tuple.  "
-      << "However, " << call << " has arguments " << arg_expr
+      << "However, " << call << " has arguments " << arg_tuple
       << ", which is neither an in-line tuple, "
       << "nor a variable binding that may be normalized to an in-line tuple.";
 
+  if (call->args.size() > 2) {
+    Expr packed_ints = call->args[2];
+    CHECK(packed_ints->struct_info_.as<ShapeStructInfoNode>())
+        << "Operation " << call->op << " expects the optional third argument, "
+        << "if present, to be a ShapeTuple.  "
+        << "However, the third argument " << packed_ints << " has struct info "
+        << packed_ints->struct_info_;
+  }
+
+  CHECK_EQ(call->sinfo_args.size(), 1)
+      << "R.call_tir should have exactly one `sinfo_args` parameter, "
+      << "which defines the output of the PrimFunc.";
+
   auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
     if (auto var = expr.as<Var>()) {
       if (auto bound_value = ctx->LookupBinding(var.value())) {
@@ -290,14 +485,21 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) 
{
     return NullOpt;
   };
 
-  while (auto unwrapped = unwrap_binding(arg_expr)) {
-    arg_expr = unwrapped.value();
-  }
+  Tuple new_arg_tuple = [&]() {
+    // No replacement required.  The argument tuple is already
+    // provided as an in-line tuple.
+    if (auto opt = arg_tuple.as<Tuple>()) {
+      return opt.value();
+    }
+
+    Expr unwrapped_tuple = arg_tuple;
+    while (auto unwrapped = unwrap_binding(unwrapped_tuple)) {
+      unwrapped_tuple = unwrapped.value();
+    }
 
-  Tuple new_arg_expr = [&]() {
     // Preferred replacement.  The argument tuple is provided as a
     // variable, but we know the value bound to that variable.
-    if (auto opt = arg_expr.as<Tuple>()) {
+    if (auto opt = unwrapped_tuple.as<Tuple>()) {
       return opt.value();
     }
 
@@ -306,20 +508,60 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) 
{
     // example, if a relax function accepted a tuple as an parameter,
     // then provided that same tuple as an argument to call_tir.
     Array<Expr> tuple_elements;
-    size_t num_fields = 
Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
+    size_t num_fields = 
Downcast<TupleStructInfo>(arg_tuple->struct_info_)->fields.size();
     for (size_t i = 0; i < num_fields; i++) {
-      tuple_elements.push_back(TupleGetItem(arg_expr, i));
+      tuple_elements.push_back(TupleGetItem(arg_tuple, i));
     }
     return Tuple(tuple_elements);
   }();
 
-  auto new_args = call->args;
-  new_args.Set(1, new_arg_expr);
-  call.CopyOnWrite()->args = new_args;
+  if (!new_arg_tuple.same_as(arg_tuple)) {
+    auto new_args = call->args;
+    new_args.Set(1, new_arg_tuple);
+    call.CopyOnWrite()->args = new_args;
+  }
 
   return std::move(call);
 }
 
+void ValidateCallTIR(Call call) {
+  // This function is used for validation of `relax.call_tir`,
+  // along with the variants `relax.call_tir_with_grad` and
+  // `relax.call_tir_inplace`.  Therefore, all error messages should
+  // be written in terms of `call->op`, and should not explicitly
+  // reference the `relax.call_tir` operator.
+
+  auto callee = call->args[0];
+  Expr arg_tuple = call->args[1];
+
+  auto packed_int_sinfo = [&]() -> Optional<StructInfo> {
+    if (call->args.size() <= 2) {
+      return NullOpt;
+    } else {
+      return GetStructInfo(call->args[2]);
+    }
+  }();
+
+  auto opt_inplace_indices = [&]() -> Optional<Array<Integer>> {
+    if (const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>()) {
+      return attrs->inplace_indices;
+    } else {
+      return NullOpt;
+    }
+  }();
+
+  StructInfo explicit_sinfo = call->sinfo_args[0];
+  auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments(
+      GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, 
opt_inplace_indices);
+  if (inferred_sinfo.defined()) {
+    CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo))
+        << "TypeError: "
+        << "The `out_sinfo` argument for R.call_tir must be compatible with 
the PrimFunc.  "
+        << "However, the PrimFunc's signature implies that the output should 
be " << inferred_sinfo
+        << ", but the `out_sinfo` argument was " << explicit_sinfo;
+  }
+}
+
 RELAY_REGISTER_OP("relax.call_tir")
     .set_num_inputs(3)
     .add_argument("func", "Expr", "The destination-passing-style function.")
@@ -329,6 +571,7 @@ RELAY_REGISTER_OP("relax.call_tir")
                   "args if unused")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
     .set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
+    .set_attr<FValidate>("FValidate", ValidateCallTIR)
     .set_attr<Bool>("FPurity", Bool(true));
 
 Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
@@ -374,6 +617,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
                   "args if unused")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
     .set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
+    .set_attr<FValidate>("FValidate", ValidateCallTIR)
     .set_attr<Bool>("FPurity", Bool(true));
 
 Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> 
out_sinfo_list,
@@ -514,6 +758,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
                   "args if unused")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
     .set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
+    .set_attr<FValidate>("FValidate", ValidateCallTIR)
     // Warning: considered pure, but it has the potential to create visible 
effects!
     // This should only be used if it has been *checked* that it is safe (no 
aliases, in-place
     // arguments will no longer be live)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index b203b322ab..612e1459c8 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -1088,8 +1088,7 @@ class TIRFuseMutator : public ExprMutator {
       const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, 
old_gvar);
 
       GlobalVar new_gvar(old_gvar->name_hint);
-      UpdateStructInfo(new_gvar,
-                       
FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)));
+      UpdateStructInfo(new_gvar, GetStructInfo(prim_func));
 
       mod->Remove(old_gvar);
       updates->Add(new_gvar, prim_func);
diff --git 
a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
 
b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
index e1f45d278d..865051b0b4 100644
--- 
a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
+++ 
b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
@@ -512,13 +512,11 @@ def test_decoder_layer():
                 cls.rotary_embedding,
                 (lv9, cos_cached, sin_cached),
                 out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
-                tir_vars=R.shape([256]),
             )
             lv17 = R.call_tir(
                 cls.rotary_embedding,
                 (lv12, cos_cached, sin_cached),
                 out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
-                tir_vars=R.shape([256]),
             )
             lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape(
                 lv17, R.shape([256, 32, 128])
@@ -712,13 +710,11 @@ def test_decoder_layer():
                 cls.rotary_embedding,
                 (lv9, cos_cached, sin_cached),
                 out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", 
"S[2]"),
-                tir_vars=R.shape([256]),
             )
             lv17 = R.dist.call_tir(
                 cls.rotary_embedding,
                 (lv12, cos_cached, sin_cached),
                 out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", 
"S[2]"),
-                tir_vars=R.shape([256]),
             )
             lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = 
R.reshape(
                 lv17, R.shape([256, 32, 128])
@@ -1278,13 +1274,11 @@ def test_decoder_layer_tir():
                 cls.rotary_embedding,
                 (lv9, cos_cached, sin_cached),
                 out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
-                tir_vars=R.shape([256]),
             )
             lv17 = R.call_tir(
                 cls.rotary_embedding,
                 (lv12, cos_cached, sin_cached),
                 out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
-                tir_vars=R.shape([256]),
             )
             lv18 = R.call_tir(
                 cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), 
dtype="float16")
@@ -1449,13 +1443,11 @@ def test_decoder_layer_tir():
                 LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
                 (lv9, cos_cached, sin_cached),
                 out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", 
"S[2]"),
-                tir_vars=R.shape([256]),
             )
             lv17 = R.dist.call_tir(
                 LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
                 (lv12, cos_cached, sin_cached),
                 out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", 
"S[2]"),
-                tir_vars=R.shape([256]),
             )
             lv18 = R.dist.call_tir(
                 LlamaAttentionLayerTIR.get_global_var("reshape1"),
diff --git a/tests/python/relax/test_analysis_well_formed.py 
b/tests/python/relax/test_analysis_well_formed.py
index 7deddfd28e..c0b962c3f3 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -14,15 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import pytest
+
 import tvm
 import tvm.testing
+
 from tvm import relax as rx
 from tvm import tir
-from tvm.script import relax as R
-from tvm.script import ir as I
-from tvm.script import tir as T
-from tvm.script import ir as I
+from tvm.script import ir as I, relax as R, tir as T
 
 m = tir.Var("m", "int64")
 n = tir.Var("n", "int64")
@@ -702,5 +702,511 @@ def test_pass_dltensor_arg_to_tir():
     assert rx.analysis.well_formed(Module)
 
 
+def test_call_tir_with_matching_arguments():
+    """R.call_tir is well-formed when called with matching arguments"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_ndim():
+    """Arguments to R.call_tir must have the correct dimensionality
+
+    Here, the `add_one` function expects a 1-d input tensor, but is
+    called with a 2-d tensor.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([4, 4], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_ndim():
+    """Output shape R.call_tir must have the correct dimensionality
+
+    Here, the `add_one` function requires a 1-d output tensor, but is
+    provided with a 2-d tensor.
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_shape():
+    """Arguments to R.call_tir must have the correct shape
+
+    Here, the `add_one` function expects an input tensor with 16
+    elements, but is called with an input tensor with 32 elements.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([32], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_shape():
+    """Output shape R.call_tir must have the correct shape
+
+    Here, the `add_one` function requires an output tensor with 16
+    elements, but is provided an output tensor with 32 elements.
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_dtype():
+    """Arguments to R.call_tir must have the correct dtype
+
+    Here, the `add_one` function expects an input tensor containing
+    float16 values, but is called with an input tensor containing
+    float32 values.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], 
"float16"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_dtype():
+    """Output of R.call_tir must have the correct dtype
+
+    Here, the `add_one` function requires an output tensor that may be
+    populated with float16 values, but is provided an output tensor
+    that may be populated with float32 elements.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], 
"float32"))
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_correct_dynamic_output_shape():
+    """Output shape R.call_tir may not be verifiable
+
+    Here, the input arguments to the `reshape` function are not
+    sufficient to infer the shape of the outputs.  This is legal,
+    since the output shape is determined by the `out_sinfo` parameter.
+
+    Inability to verify the output shape does not mean that the output
+    shape is invalid.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], 
"float16"))
+            return B
+
+        @T.prim_func
+        def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+            for i, j in T.grid(M, N):
+                with T.block("compute"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi * N + vj]
+
+    assert rx.analysis.well_formed(Module)
+
+
[email protected](reason="Not supported")
+def test_call_tir_with_incorrect_dynamic_output_shape():
+    """Output shape R.call_tir may not be verifiable
+
+    Here, the input arguments to the `reshape` function are not
+    sufficient to infer the shape of the outputs.  Even though the
+    IRModule will not provide well-defined output due to the
+    out-of-bounds read from buffer A, catching this error is beyond
+    the current scope of the Relax well-formed checker.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], 
"float16"))
+            return B
+
+        @T.prim_func
+        def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+            for i, j in T.grid(M, N):
+                with T.block("compute"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi * N + vj]
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_incorrect_dimensionality_of_output_shape():
+    """Dimensionality may be verified
+
+    Here, the input arguments to the `reshape` function are not
+    sufficient to infer the shape of the outputs.
+
+    Even though the output shape may not be inferred from the input
+    arguments, the output dimensionality can still be inferred from
+    the PrimFunc signature.  The IRModule below is ill-formed, because
+    the PrimFunc requires a 2-d output argument, but is provided with
+    a 3-d output argument.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], 
"float16"))
+            return B
+
+        @T.prim_func
+        def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+            for i, j in T.grid(M, N):
+                with T.block("compute"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi * N + vj]
+
+    assert not rx.analysis.well_formed(Module)
+
+
[email protected](reason="Not yet supported")
+def test_call_tir_output_shape_with_mixed_static_and_dynamic():
+    """Some dimensions of the R.call_tir output shape may be verifiable
+
+    Here, the input arguments to the `reshape` function are not
+    sufficient to infer the shape of the outputs.  This is legal,
+    since the output shape is taken from the `out_sinfo` parameter.
+
+    Identifying this failure mode is not yet supported in the current
+    implementation.  This is because the output is inferred as
+    `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo`
+    is a 3-d tensor.  The mismatch in the first dimension is not yet
+    counted, because the entire tensor shape is removed by
+    `EraseToWellDefined`.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([256], "float16")):
+            B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], 
"float16"))
+            return B
+
+        @T.prim_func
+        def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            B = T.match_buffer(B_handle, [16, M, N], dtype="float16")
+
+            for i, j, k in T.grid(16, M, N):
+                with T.block("compute"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    B[vi, vj, vk] = A[vi * N * M + vj * N + vk]
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_correct_inferred_dynamic_output_shape():
+    """Some dynamic output shapes of R.call_tir may be inferred
+
+    Here, the `flatten` function is dynamic, and will flatten any 2-d
+    TIR buffer.  Even though it is dynamic, the input shapes are
+    sufficient to infer that `M==8` and `N==4`.  As a result, the
+    output shape of `[M*N]` can be inferred to be `[32]`, and the
+    shape specified in `out_sinfo` can be validated.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([8, 4], "float16")):
+            B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], 
"float16"))
+            return B
+
+        @T.prim_func
+        def flatten(A_handle: T.handle, B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            A = T.match_buffer(A_handle, [M, N], dtype="float16")
+            B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+            for i in T.grid(M * N):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi // N, vi % N]
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_incorrect_inferred_dynamic_output_shape():
+    """Some dynamic output shapes of R.call_tir may be inferred
+
+    Here, the `flatten` function is dynamic, and will flatten any 2-d
+    TIR buffer.  Even though it is dynamic, the input shapes are
+    sufficient to infer that `M==8` and `N==4`.  As a result, the
+    output shape of `[M*N]` can be inferred to be `[32]`, and the
+    shape specified in `out_sinfo` can be validated.
+
+    This unit test is identical to the above test
+    `test_call_tir_with_correct_inferred_dynamic_output_shape`, except
+    that the output shape is explicitly specified as `[64]`, which is
+    caught as a mismatch from the expected output shape.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([8, 4], "float16")):
+            B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], 
"float16"))
+            return B
+
+        @T.prim_func
+        def flatten(A_handle: T.handle, B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            A = T.match_buffer(A_handle, [M, N], dtype="float16")
+            B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+            for i in T.grid(M * N):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi // N, vi % N]
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_dtensor_arguments():
+    """R.call_tir and R.dist.call_tir share the same operation
+
+    Both `R.call_tir` and `R.dist.call_tir` produce the same
+    "relax.call_tir" operation, differing only in the StructInfo of
+    their arguments.  Normalization of "relax.call_tir" must handle
+    `R.DTensor` arguments.
+
+    """
+
+    # from tvm.script.parser import relax as R
+
+    @I.ir_module
+    class Module:
+        I.module_attrs({"device_num": 4})
+        I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 
4))]})
+
+        @R.function
+        def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")):
+            B = R.dist.call_tir(
+                Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", 
"mesh[0]", "S[0]")
+            )
+            return B
+
+        @T.prim_func
+        def flatten(A_handle: T.handle, B_handle: T.handle):
+            M = T.int64()
+            N = T.int64()
+            A = T.match_buffer(A_handle, [M, N], dtype="float16")
+            B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+            for i in T.grid(M * N):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = A[vi // N, vi % N]
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_correct_shapes():
+    """R.call_tir_inplace is well-formed when called with matching arguments"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir_inplace(
+                Module.add_one,
+                A,
+                inplace_indices=[0],
+                out_sinfo=R.Tensor([16], "float16"),
+            )
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    A[vi] = A[vi] + T.float16(1.0)
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_incorrect_shapes():
+    """R.call_tir_inplace is ill-formed when output shape does not match 
input"""
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.call_tir_inplace(
+                Module.add_one,
+                A,
+                inplace_indices=[0],
+                out_sinfo=R.Tensor([32], "float16"),
+            )
+            return B
+
+        @T.prim_func
+        def add_one(A: T.Buffer(16, "float16")):
+            for i in range(16):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    A[vi] = A[vi] + T.float16(1.0)
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_some_allocated_outputs():
+    """R.call_tir_inplace may contain some non-inplace outputs"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")):
+            out = R.call_tir_inplace(
+                Module.add_one,
+                (A, B),
+                inplace_indices=[-1, 1],
+                out_sinfo=[
+                    R.Tensor([16], "float16"),
+                    R.Tensor([32], "float16"),
+                ],
+            )
+            return out
+
+        @T.prim_func
+        def add_one(
+            A: T.Buffer(16, "float16"),
+            B: T.Buffer(32, "float16"),
+            C: T.Buffer(16, "float16"),
+        ):
+            for i in range(32):
+                with T.block("inplace_B"):
+                    vi = T.axis.remap("S", [i])
+                    B[vi] = B[vi] + T.float16(1.0)
+
+            for i in range(16):
+                with T.block("output_C"):
+                    vi = T.axis.remap("S", [i])
+                    C[vi] = A[vi] + T.float16(1.0)
+
+    assert rx.analysis.well_formed(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_ast_printer.py 
b/tests/python/relax/test_ast_printer.py
index 64d5c73811..6005ecb0fa 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -43,6 +43,7 @@ def normalize(func: rx.Function) -> rx.Function:
     """
     Normalize the expr to fill in the checked_type_ and struct_info fields 
everywhere
     """
+
     # using a default mutator to use the BlockBuilder's normalizer,
     # which oddly differs from the Normalize pass
     @rx.expr_functor.mutator
@@ -435,9 +436,13 @@ def test_call_tir():
     @tvm.script.ir_module
     class TestCallTIR:
         @T.prim_func
-        def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), 
"int32")) -> None:
+        def addone(A_handle: T.handle, B_handle: T.handle) -> None:
+            m = T.int64()
+            n = T.int64()
+            A = T.match_buffer(A_handle, (m, n), "float32")
+            B = T.match_buffer(B_handle, (m, n), "float32")
             T.func_attr(({"global_symbol": "addone"}))
-            for i, j in T.grid(16, 16):
+            for i, j in T.grid(m, n):
                 with T.block("addone"):
                     vi, vj = T.axis.remap("SS", [i, j])
                     B[vi, vj] = A[vi, vj] + T.int32(1)
diff --git a/tests/python/relax/test_dataflow_inplace.py 
b/tests/python/relax/test_dataflow_inplace.py
index 8d5eb07c78..cd6e285de4 100644
--- a/tests/python/relax/test_dataflow_inplace.py
+++ b/tests/python/relax/test_dataflow_inplace.py
@@ -172,8 +172,8 @@ def test_alias_call_tir():
             T.func_attr({"global_symbol": "tir_id"})
             m = T.int32()
             n = T.int32()
-            A = T.match_buffer(x, (m, n))
-            B = T.match_buffer(y, (m, n))
+            A = T.match_buffer(x, (m, n), "int32")
+            B = T.match_buffer(y, (m, n), "int32")
 
             for i, j in T.grid(m, n):
                 with T.block("id"):
@@ -185,9 +185,9 @@ def test_alias_call_tir():
             T.func_attr({"global_symbol": "tir_id"})
             m = T.int32()
             n = T.int32()
-            A = T.match_buffer(x, (m, n))
-            B = T.match_buffer(y, (m, n))
-            C = T.match_buffer(z, (m, n))
+            A = T.match_buffer(x, (m, n), "int32")
+            B = T.match_buffer(y, (m, n), "int32")
+            C = T.match_buffer(z, (m, n), "int32")
 
             for i, j in T.grid(m, n):
                 with T.block("id"):
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index 03a3beb2f2..7a3b65cea1 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -72,7 +72,7 @@ class Module:
             lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), 
dtype="float32"))
             lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), 
dtype="float32"))
             lv2 = R.call_tir(
-                cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), 
tir_vars=R.ShapeExpr([32])
+                cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), 
tir_vars=R.ShapeExpr([32])
             )
             gv = (lv1, lv2)
             R.output(gv)
diff --git a/tests/python/relax/test_frontend_dynamo.py 
b/tests/python/relax/test_frontend_dynamo.py
index d83f83f4e1..21e1d82d28 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -114,9 +114,10 @@ def test_relax_dynamo():
     with db:
         opt_model = torch.compile(model, backend=relax_dynamo())
     inp = torch.randn(10, 100)
-    tvm.testing.assert_allclose(
-        opt_model(inp).detach().numpy(), model(inp).detach().numpy(), 
rtol=1e-5, atol=1e-5
-    )
+
+    default_output = model(inp).detach().numpy()
+    optimized_output = opt_model(inp).detach().numpy()
+    tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, 
atol=1e-5)
 
 
 def test_relax_dynamo_dynamic():
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index 40624790cb..6a337b34c1 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -570,10 +570,18 @@ def test_tensor_ir_op():
     @T.prim_func(private=True)
     def fused_rope(  # pylint: disable=too-many-locals
         var_qkv: T.handle,
-        offset: T.int64,
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
+        # Scalar arguments must be specified after tensor arguments,
+        # including the output tensor arguments
+        #
+        # TODO(Lunderberg): Update
+        # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue`
+        # instead of `tir_vars`, so that the order can be consistent
+        # between the function definition and the arguments in
+        # `op.tensor_ir_op`.
+        offset: T.int64,
     ):
         batch_size = T.int64()
         seq_len = T.int64()
@@ -601,7 +609,7 @@ def test_tensor_ir_op():
     @I.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: 
T.handle, var_k: T.handle, var_v: T.handle):
+        def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: 
T.handle, var_v: T.handle, offset: T.int64):
             batch_size, seq_len = T.int64(), T.int64()
             qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), 
"float16")
             q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
@@ -669,10 +677,11 @@ def test_tensor_ir_inplace_op():
         def test(
             self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: 
Tensor, offset: int
         ):
-            tensor_expr_op_out = op.tensor_ir_op(
+            tensor_expr_op_out = op.tensor_ir_inplace_op(
                 inplace_take,
                 "inplace_take",
                 args=[embedding_table, input_ids, embedding_dst, offset],
+                inplace_indices=[2],
                 out=Tensor.placeholder(embedding_dst.shape, 
embedding_dst.dtype),
             )
             return tensor_expr_op_out
@@ -719,10 +728,11 @@ def test_tensor_ir_inplace_op():
             R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                lv1 = R.call_tir(
+                lv1 = R.call_tir_inplace(
                     cls.inplace_take,
                     (embedding_table, input_ids, embedding_dst),
                     out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
+                    inplace_indices=[2],
                     tir_vars=R.shape([offset_1]),
                 )
                 gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
diff --git a/tests/python/relax/test_transform.py 
b/tests/python/relax/test_transform.py
index ee2df866fb..e3274aea88 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -86,7 +86,11 @@ def test_call_tir_rewrite():
     @tvm.script.ir_module
     class TestCallTIRRewrite:
         @T.prim_func
-        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), 
"float32")):
+        def exp(A_handle: T.handle, B_handle: T.handle):
+            m = T.int64()
+            n = T.int64()
+            A = T.match_buffer(A_handle, (m, n), "float32")
+            B = T.match_buffer(B_handle, (m, n), "float32")
             T.evaluate(0)
 
         @R.function
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
index 65970d6455..0ddf985ec4 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -277,18 +277,26 @@ def 
test_tracking_through_externally_exposed_func(provide_entry_func_name):
 
 def test_unused_relax_func_symbolic_shape():
     # Test with relax function w/ symbolic shape.
-    @tvm.script.ir_module
+    @tvm.script.ir_module(check_well_formed=False)
     class InputModule:
         @T.prim_func
-        def tir_add(
-            x: T.Buffer((16, 16), "float32"),
-            y: T.Buffer((16, 16), "float32"),
-            z: T.Buffer((16, 16), "float32"),
+        def tir_matmul(
+            x_handle: T.handle,
+            y_handle: T.handle,
+            z_handle: T.handle,
         ) -> None:
-            for i, j in T.grid(16, 16):
-                with T.block("add"):
-                    vi, vj = T.axis.remap("SS", [i, j])
-                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
+            x = T.match_buffer(x_handle, (m, n), "float32")
+            y = T.match_buffer(y_handle, (n, k), "float32")
+            z = T.match_buffer(z_handle, (m, k), "float32")
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        z[vi, vj] = 0.0
+                    z[vi, vj] = z[vi, vj] + x[vi, vk] * y[vk, vj]
 
         @R.function(private=True)
         def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", 
"k"), "float32")):
@@ -298,7 +306,7 @@ def test_unused_relax_func_symbolic_shape():
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")):
             m, k = T.int64(), T.int64()
-            gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k), 
dtype="float32"))
+            gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
             return gv0
 
     mod = InputModule
@@ -306,7 +314,7 @@ def test_unused_relax_func_symbolic_shape():
 
     new_mod = DeadCodeElimination()(mod)
     assert check_if_func_exists(new_mod, "main")
-    assert check_if_func_exists(new_mod, "tir_add")
+    assert check_if_func_exists(new_mod, "tir_matmul")
     assert not check_if_func_exists(new_mod, "unused_func")
 
 
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 17bf586132..9ad66bec01 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -875,7 +875,7 @@ def test_layer_norm_silu():
         def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 
64), "float32"), var: R.Tensor((64, 64), "float32")):
             cls = Module
             with R.dataflow():
-                gv0 = R.call_tir(cls.layer_norm, (x, mean, var), 
out_sinfo=R.Tensor((1, 512, 64, 64)))
+                gv0 = R.call_tir(cls.layer_norm, (x, mean, var), 
out_sinfo=R.Tensor((1, 512, 64, 64), 'float32'))
                 gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 
64, 64), "float32"))
                 R.output(gv1)
             return gv1
@@ -955,7 +955,7 @@ def test_layer_norm_silu():
             R.func_attr({"Primitive": 1})
             cls = Expected
             with R.dataflow():
-                gv0 = R.call_tir(cls.layer_norm, (x, mean, var), 
out_sinfo=R.Tensor((1, 512, 64, 64)))
+                gv0 = R.call_tir(cls.layer_norm, (x, mean, var), 
out_sinfo=R.Tensor((1, 512, 64, 64), 'float32'))
                 gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 
64, 64), dtype="float32"))
                 R.output(gv)
             return gv
@@ -1452,7 +1452,7 @@ def test_partially_used_tuple_param():
                 R.Tensor((2,), "float32"),
                 R.Tensor((2,), "float32"),
                 R.Tensor((2,), "float32"),
-            )
+            ),
         ):
             with R.dataflow():
                 x0 = x[0]
@@ -1486,7 +1486,7 @@ def test_partially_used_tuple_param():
                 R.Tensor((2,), dtype="float32"),
                 R.Tensor((2,), dtype="float32"),
                 R.Tensor((2,), dtype="float32"),
-            )
+            ),
         ) -> R.Tensor((2,), dtype="float32"):
             cls = Expected
             with R.dataflow():
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py 
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 1582526042..a07875fcda 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -696,10 +696,10 @@ def test_ignore_call_tir():
     class Conv2dReLUCallTIR:
         @T.prim_func
         def relu(
-            data: T.Buffer((64, 64, 56, 56), "float32"),
-            out: T.Buffer((64, 64, 56, 56), "float32"),
+            data: T.Buffer((1, 64, 56, 56), "float32"),
+            out: T.Buffer((1, 64, 56, 56), "float32"),
         ):
-            for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+            for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
                 with T.block("root"):
                     i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                     out[i, j, k, l] = T.max(data[i, j, k, l], 0.0)
@@ -714,7 +714,7 @@ def test_ignore_call_tir():
                 relu1 = R.call_tir(
                     Conv2dReLUCallTIR.relu,
                     (conv1,),
-                    R.Tensor((64, 64, 56, 56), "float32"),
+                    R.Tensor((1, 64, 56, 56), "float32"),
                 )
                 R.output(relu1)
 
@@ -724,11 +724,11 @@ def test_ignore_call_tir():
     class Conv2dReLUCallTIR_partitioned:
         @T.prim_func
         def relu(
-            data: T.Buffer((64, 64, 56, 56), "float32"),
-            out: T.Buffer((64, 64, 56, 56), "float32"),
+            data: T.Buffer((1, 64, 56, 56), "float32"),
+            out: T.Buffer((1, 64, 56, 56), "float32"),
         ):
             # with T.block("root"):
-            for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+            for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
                 with T.block("root"):
                     i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                     T.reads(data[i, j, k, l])
@@ -754,7 +754,7 @@ def test_ignore_call_tir():
         def main(
             data: R.Tensor((1, 64, 56, 56), dtype="float32"),
             weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
-        ) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
             cls = Conv2dReLUCallTIR_partitioned
             with R.dataflow():
                 lv: R.Tensor((1, 64, 56, 56), dtype="float32") = 
cls.fused_relax_nn_conv2d(
@@ -763,7 +763,7 @@ def test_ignore_call_tir():
                 relu1 = R.call_tir(
                     cls.relu,
                     (lv,),
-                    out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"),
+                    out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"),
                 )
                 R.output(relu1)
             return relu1
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py 
b/tests/python/relax/test_transform_lazy_transform_params.py
index 278ac825f7..87a5698f1b 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -43,7 +43,7 @@ def test_lazy_transform_params():
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
         ):
@@ -124,7 +124,7 @@ def test_get_item_only():
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
         ):
@@ -209,7 +209,7 @@ def test_extra_get_item_params():
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
         ):
@@ -298,7 +298,7 @@ def test_extra_set_item_params():
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 
3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), 
dtype="float32")
         ):
@@ -441,8 +441,8 @@ def test_lazy_transform_params_with_symbolic_vars():
         @T.prim_func(private=True)
         def slice_buffer(
             Input: T.Buffer((16, 16), "float32"),
-            slice_index: T.int64,
             Output: T.Buffer(16, "float32"),
+            slice_index: T.int64,
         ):
             for i in T.grid(16):
                 with T.block("slice_buffer"):
@@ -479,8 +479,8 @@ def test_lazy_transform_params_with_symbolic_vars():
         @T.prim_func(private=True)
         def slice_buffer(
             Input: T.Buffer((16, 16), "float32"),
-            slice_index: T.int64,
             Output: T.Buffer(16, "float32"),
+            slice_index: T.int64,
         ):
             for i in T.grid(16):
                 with T.block("slice_buffer"):
@@ -511,7 +511,7 @@ def test_param_shape_symbolic():
             params: R.Tuple(
                 R.Tensor((3, "ic", 3, 3), dtype="float32"),
                 R.Tensor((16, 16, 3, 3), dtype="float32"),
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 
3), dtype="float32")
         ):
@@ -637,7 +637,7 @@ def test_output():
             params: R.Tuple(
                 R.Tensor((3, "ic", 3, 3), dtype="float32"),
                 R.Tensor((16, 16, 3, 3), dtype="float32"),
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 
3), dtype="float32")
         ):
@@ -691,7 +691,7 @@ def test_duplicate_outputs():
     class Before:
         @R.function
         def main_transform_params(
-            params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], 
dtype="int32"))
+            params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], 
dtype="int32")),
         ):
             R.func_attr({"relax.force_pure": True})
             param0 = params[0]
@@ -966,7 +966,7 @@ def test_get_item_callback_dynamic_shape():
     class Expected:
         @R.function
         def transform_params(
-            fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
+            fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object),
         ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, 
dtype="float32")):
             R.func_attr({"num_input": 1})
             m = T.int64()
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py 
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index f7befd3b88..5a7d76d8fe 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -252,11 +252,15 @@ def test_reshape_dynamic_shape():
                         ]
 
         @R.function
-        def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), 
dtype="float32"):
+        def main(
+            x: R.Tensor((8, 16, 128), dtype="float16")
+        ) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
             cls = Module
             with R.dataflow():
-                y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 
3), dtype="float32"))
-                z = R.add(y, R.const(1, "float32"))
+                y = R.call_tir(
+                    cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), 
dtype="float16")
+                )
+                z = R.add(y, R.const(1, "float16"))
                 R.output(z)
             return z
 
@@ -290,10 +294,14 @@ def test_reshape_dynamic_shape():
                         ]
 
         @R.function
-        def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), 
dtype="float32"):
+        def main(
+            x: R.Tensor((8, 16, 128), dtype="float16")
+        ) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
             with R.dataflow():
-                y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, 
R.shape([2, 4, 3]))
-                z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, 
"float32"))
+                y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape(
+                    x, R.shape([1, 8, 16, 128])
+                )
+                z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y, 
R.const(1, "float16"))
                 R.output(z)
             return z
 
@@ -383,7 +391,7 @@ def test_tuple_get_reshape():
                 R.Tensor((2, 4096, 320), dtype="float16"),
                 R.Tensor((2, 4096, 320), dtype="float16"),
                 R.Tensor((2, 4096, 320), dtype="float16"),
-            )
+            ),
         ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
             cls = Module
             with R.dataflow():
@@ -444,7 +452,7 @@ def test_tuple_get_reshape():
                 R.Tensor((2, 4096, 320), dtype="float16"),
                 R.Tensor((2, 4096, 320), dtype="float16"),
                 R.Tensor((2, 4096, 320), dtype="float16"),
-            )
+            ),
         ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
             with R.dataflow():
                 lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
@@ -735,7 +743,6 @@ def test_rewrite_dynamic_reshape():
             z_handle: T.handle,
             N: T.int64,
         ):
-
             y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
             y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
             z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index ea99d49270..64f2efd4af 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -77,7 +77,7 @@ def test_mismatch_cast_dims_and_ndim():
 
         @R.function
         def f(
-            x: R.Tensor((2, 3), "float32", ndim=3)
+            x: R.Tensor((2, 3), "float32", ndim=3),
         ):  # error: ndim and the shape dims are mismatch
             return x
 
@@ -961,11 +961,11 @@ def test_call_tir_with_tir_var():
     class Module:
         @R.function
         def main(
-            dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", 
"float32"))
+            dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), 
"float32")
         ) -> R.Tensor(("n * 2",), "float32"):
             n = T.int64()
             cls = Module
-            y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)), 
dtype="float32"), tir_vars=(n,))
+            y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), 
tir_vars=(n,))
             return y
 
         @T.prim_func
@@ -2171,7 +2171,9 @@ def test_macro_hygienic():
     @R.function(private=True)
     def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
         alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
-            R.shape([4, 4]), R.dtype("float32"), R.prim_value(2)  # Make sure 
prim_value is 2
+            R.shape([4, 4]),
+            R.dtype("float32"),
+            R.prim_value(2),  # Make sure prim_value is 2
         )
         shape: R.Shape([4, 4]) = R.shape_of(alloc)
         shape_1: R.Shape([4, 4]) = shape
@@ -2203,7 +2205,9 @@ def test_macro_non_hygienic():
     @R.function(private=True)
     def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
         alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
-            R.shape([4, 4]), R.dtype("float32"), R.prim_value(1)  # Make sure 
prim_value is 1
+            R.shape([4, 4]),
+            R.dtype("float32"),
+            R.prim_value(1),  # Make sure prim_value is 1
         )
         shape: R.Shape([4, 4]) = R.shape_of(alloc)
         shape_1: R.Shape([4, 4]) = shape
@@ -2372,7 +2376,6 @@ def 
test_conditional_may_use_symbolic_variables_from_function_scope():
         B: R.Tensor(["N"], "float32"),
         cond: R.Prim("bool"),
     ) -> R.Tensor(["N"], "float32"):
-
         N = T.int64()
 
         if cond:
diff --git a/tests/python/relax/test_vm_build.py 
b/tests/python/relax/test_vm_build.py
index 30fd06d4f1..ecf33aa9da 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -988,8 +988,10 @@ def test_multi_systemlib(exec_mode):
         I.module_attrs({"system_lib_prefix": "libA_"})
 
         @T.prim_func
-        def tir_init(x: T.Buffer((2), "float32")) -> None:
-            for i in range(2):
+        def tir_init(x_handle: T.handle):
+            N = T.int64()
+            x = T.match_buffer(x_handle, [N], "float32")
+            for i in range(N):
                 x[i] = T.float32(0)
 
         @R.function
@@ -1003,8 +1005,10 @@ def test_multi_systemlib(exec_mode):
         I.module_attrs({"system_lib_prefix": "libB_"})
 
         @T.prim_func
-        def tir_init(x: T.Buffer((2), "float32")) -> None:
-            for i in range(2):
+        def tir_init(x_handle: T.handle):
+            N = T.int64()
+            x = T.match_buffer(x_handle, [N], "float32")
+            for i in range(N):
                 x[i] = T.float32(1)
 
         @R.function

Reply via email to