This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eb5458e0e9 [Relax] Allow R.Prim('bool') in relax::If and assert_op (#16642)
eb5458e0e9 is described below

commit eb5458e0e9b93001bb7e4a69d7d4e393cd55c933
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 28 16:53:47 2024 -0500

    [Relax] Allow R.Prim('bool') in relax::If and assert_op  (#16642)
    
    * [TIR][Analysis] Implemented tir.analysis.is_pure_function
    
    This commit introduces two related utilities,
    `tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
    In contrast to the existing `tvm::tir::SideEffect`, which checks
    for side effects of a single `PrimExpr`, `is_pure_function` checks
    for side effects of the function as a whole.
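
    A usage sketch of the new bindings (mirroring the unit tests added
    in tests/python/tir-analysis/test_tir_analysis_is_pure_function.py):

        from tvm.script import tir as T
        from tvm.tir.analysis import assert_pure_function, is_pure_function

        @T.prim_func
        def func(N: T.int32, M: T.int32) -> T.int32:
            T.ret(N * M)

        assert is_pure_function(func)
        assert_pure_function(func)  # raises AssertionError for impure functions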
    
    * [Transform] Implement relax.transform.ComputePrimValue
    
    Prior to this commit, while expressions of type `DataType::Int(64)`
    could be computed by `relax.transform.VMShapeLower`, expressions
    of any other type could not.  This commit introduces
    `relax.transform.ComputePrimValue`, which produces `PrimFunc`
    subroutines to compute `PrimExpr` values of any dtype.
    
    This functionality will allow boolean values to be computed based on
    the symbolic values known at runtime.
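
    As a sketch of the rewrite (mirroring the test cases added in
    tests/python/relax/test_transform_compute_prim_value.py, where
    `Module` stands for the enclosing IRModule), an assertion on a
    symbolic variable `N`,

        _ = R.assert_op(N % 16 == 0)

    becomes a call to a generated PrimFunc subroutine:

        condition: R.Prim("bool") = Module.compute_symbolic_expr(R.prim_value(N))
        _ = R.assert_op(condition)

        @T.prim_func(private=True)
        def compute_symbolic_expr(N: T.int64) -> T.bool:
            T.ret(N % 16 == 0)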
    
    * [Relax] Allow R.Prim('bool') in relax::If and assert_op
    
    Prior to this commit, the condition used for a `relax::If` node and
    the `"relax.assert_op"` operator was required to be a scalar tensor.
    This made it difficult to alter behavior based on a runtime shape
    parameter, for example delegating to a vectorized implementation
    based on whether a tensor shape is divisible by the vector size.
    
    This commit adds support for expressions of type `R.Prim('bool')` as
    the conditional for `relax::If` and `"relax.assert_op"`, to allow
    these use cases.
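
    For example (a sketch mirroring the new parser tests in
    tests/python/relax/test_tvmscript_parser.py), a branch may now be
    taken on a symbolic boolean, with a bare TIR expression accepted
    as sugar for `R.prim_value`:

        @R.function
        def func(x: R.Tensor(["N"], "float32")):
            N = T.int64()
            if R.prim_value(N % 16 == 0):
                out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
            else:
                out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info])
            return out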
    
    * Lint fix
---
 include/tvm/tir/analysis.h                         |  15 +-
 python/tvm/error.py                                |   1 +
 python/tvm/relax/op/base.py                        |  44 +++--
 python/tvm/relax/pipeline.py                       |   1 +
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  19 ++
 python/tvm/script/ir_builder/relax/ir.py           |  15 +-
 python/tvm/script/parser/tir/parser.py             |  33 +++-
 python/tvm/tir/analysis/analysis.py                |  10 ++
 src/relax/analysis/struct_info_analysis.cc         |   6 +-
 src/relax/backend/vm/vm_shape_lower.cc             |   1 +
 src/relax/op/tensor/inspect.cc                     |   4 +-
 src/relax/transform/compute_prim_value.cc          |  94 ++++++++++
 src/relax/transform/dataflow_inplace.cc            |  45 ++---
 src/relax/utils.cc                                 |  17 +-
 src/tir/analysis/is_pure_function.cc               |  97 ++++++++++
 src/tir/ir/function.cc                             |  43 +++++
 src/tir/ir/specialize.cc                           |  10 +-
 src/tir/transforms/renew_defs.cc                   |   6 +-
 tests/python/relax/test_analysis_well_formed.py    |  46 +++++
 .../relax/test_backend_transform_shape_lower.py    |  84 +++++++++
 tests/python/relax/test_relax_operators.py         | 195 ++++++++++++++++-----
 tests/python/relax/test_transform.py               |  12 +-
 .../relax/test_transform_compute_prim_value.py     | 104 +++++++++++
 tests/python/relax/test_tvmscript_parser.py        | 147 +++++++++++++++-
 tests/python/relax/test_vm_codegen_tir.py          |   2 +-
 .../test_tir_analysis_is_pure_function.py          | 104 +++++++++++
 tests/python/tir-base/test_tir_specialize.py       |  27 ++-
 .../python/tvmscript/test_tvmscript_parser_tir.py  | 109 ++++++++++++
 29 files changed, 1154 insertions(+), 138 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index c4ae5d573b..96459f25ec 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -117,13 +117,26 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& defs);
 
 /*!
- * \brief Analyze the side effect
+ * \brief Analyze the side effect of an expression
  * \param expr The expression to be checked.
  *
  * \return CallEffectKind, can be kPure, kReadState or kUpdateState
  */
 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
 
+/*!
+ * \brief Analyze the side effect of a function
+ *
+ * \param func The function to be checked.
+ *
+ * \param assert_on_error If true, an error will be thrown for an
+ *    impure function.  If false (default), the purity of the PrimFunc
+ *    will be returned.
+ *
+ * \return The purity of the function
+ */
+TVM_DLL bool IsPureFunction(const PrimFunc& func, bool assert_on_error = false);
+
 /*!
  * \brief Whether the given Stmt uses any var in the given variable set.
  * \param stmt The Stmt to be checked.
diff --git a/python/tvm/error.py b/python/tvm/error.py
index cc0180593d..6bf9b16850 100644
--- a/python/tvm/error.py
+++ b/python/tvm/error.py
@@ -54,6 +54,7 @@ register_error("TypeError", TypeError)
 register_error("AttributeError", AttributeError)
 register_error("KeyError", KeyError)
 register_error("IndexError", IndexError)
+register_error("AssertionError", AssertionError)
 
 
 @register_error
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 3effec242d..756d250c16 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -503,19 +503,26 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob
             f"The format string argument to assert must be a string, given {type(format_str)})"
         )
 
-    # should be guaranteed by the type system
-    if not isinstance(condition, tvm.nd.NDArray):
-        raise ValueError(f"The condition must be an NDArray, but given a 
{type(condition)}.")
-
-    # may happen if the original program had unknown shape or dtype for the 
tensor's type
-    dtype = condition.dtype
-    if dtype != "bool":
-        raise ValueError(f"The condition must be a bool scalar, but given a 
{dtype} tensor")
-    shape = condition.shape
-    if len(shape) != 0:
-        raise ValueError(f"The condition must be a scalar, but it has a shape 
of {shape}")
-
-    val = condition.numpy()
+    if isinstance(condition, (bool, int)):
+        val = condition
+    elif isinstance(condition, tvm.nd.NDArray):
+        # may happen if the original program had unknown shape or dtype for the tensor's type
+        dtype = condition.dtype
+        if dtype != "bool":
+            raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
+        shape = condition.shape
+        if len(shape) != 0:
+            raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")
+
+        val = condition.numpy()
+
+    else:
+        # should be guaranteed by the type system
+        raise ValueError(
+            f"The condition for relax assert must be a bool, int, or NDArray, "
+            f"but received a {type(condition)}."
+        )
+
     if not val:
         error_message = "Assertion Failed"
         if format_args or format_str != "":
@@ -528,7 +535,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob
 
 
 def assert_op(
-    condition: Expr,
+    condition: Union[Expr, PrimExpr],
     format_args: Optional[Union[Expr, List[Expr]]] = None,
     format: Union[str, Expr] = "",
 ) -> Expr:
@@ -538,7 +545,7 @@ def assert_op(
 
     Parameters
     ----------
-    condition: Expr
+    condition: Union[Expr, PrimExpr]
         The assertion condition.
 
     format_args: Optional[Union[Expr, List[Expr]]]
@@ -552,12 +559,17 @@ def assert_op(
     result : Expr
         A Call to the Relax assert operation.
     """
+    if not isinstance(condition, Expr):
+        condition = tvm.relax.PrimValue(condition)
+
     if format_args is None:
         format_args = []
-    if isinstance(format_args, Expr):  # type: ignore
+    elif isinstance(format_args, Expr):
         format_args = [format_args]
+
     if isinstance(format, str):
         format = StringImm(format)
+
     return _ffi_api.assert_op(condition, format_args, format)  # type: ignore
 
 
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 474833bdfd..36ba46a1a5 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,6 +92,7 @@ def default_build_pipeline():
                 transform.LowerAllocTensor(),
                 transform.KillAfterLastUse(),
                 transform.VMBuiltinLower(),
+                transform.ComputePrimValue(),
                 transform.VMShapeLower(),
                 transform.AttachGlobalSymbol(),
             ],
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5f10c39d82..11e301c26c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -28,6 +28,7 @@ from .transform import (
     CallTIRRewrite,
     CanonicalizeBindings,
     CombineParallelMatmul,
+    ComputePrimValue,
     ConvertLayout,
     ConvertToDataflow,
     DataflowBlockPass,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index ef10f5791d..dbc35d48d3 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -486,6 +486,25 @@ def KillAfterLastUse() -> tvm.ir.transform.Pass:
     return _ffi_api.KillAfterLastUse()  # type: ignore
 
 
+def ComputePrimValue() -> tvm.ir.transform.Pass:
+    """Compute all R.prim_value instances
+
+    While high-level relax can include expressions in terms of its
+    symbolic variables, these expressions cannot natively be computed
+    within relax.  In order to provide values for symbolic expressions
+    (e.g. `R.prim_value(N*N)`, where `N` is a symbolic variable), this
+    pass generates a PrimFunc in which the expression can be computed.
+    The relax graph is then updated to include a call to that
+    PrimFunc, in place of the original `R.prim_value(expr)`.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+
+    """
+    return _ffi_api.ComputePrimValue()  # type: ignore
+
+
 def VMBuiltinLower() -> tvm.ir.transform.Pass:
     """Lowering generic intrinsic to VM intrinsics.
 
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 3e1927290d..6dbf5c5dfd 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -511,18 +511,25 @@ def SeqExpr() -> frame.SeqExprFrame:  # pylint: disable=invalid-name
 ############################# If Then Else #############################
 
 
-def If(condition: Expr) -> frame.IfFrame:  # pylint: disable=invalid-name
+def If(condition: Union[Expr, PrimExpr]) -> frame.IfFrame:  # pylint: disable=invalid-name
     """Create an if frame.
+
     Parameters
     ----------
-    condition : Expr
-        The condition of if statement, executes the true branch if the condition is true,
-        otherwise jump into the false branch.
+    condition : Union[Expr, PrimExpr]
+
+        The condition of the if statement; the true branch executes if
+        the condition is true, and the false branch otherwise.
+
     Returns
     -------
     res : frame.IfFrame
         The result IfFrame.
+
     """
+    if not isinstance(condition, Expr):
+        condition = relax.PrimValue(condition)
+
     return _ffi_api.If(condition)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 0f3f3de60f..679ae4e8ad 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -537,12 +537,31 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
         The doc AST return node.
     """
 
-    ret_type = None
-    if node.returns is not None:
-        ret_type = self.eval_expr(node.returns)
-        if callable(ret_type):
-            ret_type = PrimType(ret_type().dtype)
+    supplied_annotation = self.function_annotations
+    func_annotation = supplied_annotation.get(node.name, {})
 
-    # Only ret_type is needed for func_signature.
-    func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
+    ret_type = None
+    with self.var_table.with_frame():
+        if node.returns is not None:
+            ret_type = self.eval_expr(node.returns)
+            if callable(ret_type):
+                ret_type = PrimType(ret_type().dtype)
+
+        arg_annotations = []
+        for arg in node.args.args:
+            if arg.annotation is None:
+                self.report_error(arg, "Type annotation required for function parameters.")
+            try:
+                ann = self.eval_expr(arg.annotation)
+                if callable(ann):
+                    ann = ann()
+            except Exception:  # pylint: disable=broad-except
+                ann = func_annotation.get(arg.arg, None)
+                if ann is None:
+                    raise
+
+            IRBuilder.name(arg.arg, ann)
+            arg_annotations.append(ann)
+
+    func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type)
     return I.decl_function(node.name, func_signature)
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 8d7e81d7d0..67eb7471d2 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -417,3 +417,13 @@ def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
         returns list of passes
     """
     return _ffi_api.get_vtcm_compaction_passes()  # type: ignore # pylint: disable=no-member
+
+
+def is_pure_function(func: PrimFunc) -> bool:
+    """Checks if the function is a pure function"""
+    return _ffi_api.is_pure_function(func, False)  # type: ignore # pylint: disable=no-member
+
+
+def assert_pure_function(func: PrimFunc) -> bool:
+    """Asserts that the function is a pure function"""
+    return _ffi_api.is_pure_function(func, True)  # type: ignore # pylint: disable=no-member
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index b1932f9b5d..08e2acfbd0 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -840,8 +840,10 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker {
     auto params = finfo->params.value();
     if (params.size() != call->args.size()) {
       ctx->ReportFatal(Diagnostic::Error(call->span)
-                       << "number of arguments and parameters mismatch:"
-                       << " expected " << params.size() << ", given " << 
call->args.size());
+                       << "Number of arguments and parameters mismatch:"
+                       << " Function " << call->op << " has struct info " << 
finfo
+                       << " and accepts " << params.size() << " parameters, 
but was called with "
+                       << call->args.size() << " arguments (" << call->args << 
")");
     }
     // Visit each param arg pair, check and populate the var map
     for (size_t i = 0; i < params.size(); ++i) {
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 06c2e31767..8dca06c840 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -85,6 +85,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
       collector.VisitExpr(param);
     }
     collector.VisitExpr(func->body);
+    collector.VisitStructInfo(func->ret_struct_info);
   }
 
  private:
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index 186fc9fa86..3772e530ed 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -107,7 +107,7 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType
 
   FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)},
                        PrimStructInfo(field_dtype));
-  UpdateStructInfo(func, sinfo);
+  func->struct_info_ = sinfo;
 
   return func;
 }
@@ -338,7 +338,7 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) {
     FuncStructInfo sinfo(
         {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)},
         PrimStructInfo(field_dtype));
-    UpdateStructInfo(func, sinfo);
+    func->struct_info_ = sinfo;
     return func;
   }();
 
diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc
new file mode 100644
index 0000000000..9fe2a3a06f
--- /dev/null
+++ b/src/relax/transform/compute_prim_value.cc
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+class PrimValueComputeInjector : public ExprMutator {
+ public:
+  IRModule Finalize() const { return builder_->Finalize(); }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const PrimValueNode* op) override {
+    auto node = Downcast<PrimValue>(ExprMutator::VisitExpr_(op));
+
+    if (node->value->IsInstance<tir::IntImmNode>() || node->value->IsInstance<tir::VarNode>()) {
+      return node;
+    }
+
+    auto ret_dtype = node->value->dtype;
+    auto param_vars = tir::UndefinedVars(node->value);
+    tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value}));
+
+    tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
+    func = tir::RenewDefs(func);
+
+    auto callee = builder_->AddFunction(func, "compute_symbolic_expr");
+
+    return relax::Call(callee, param_vars.Map([](const tir::Var& tir_var) -> relax::Expr {
+      return relax::PrimValue(tir_var);
+    }));
+  }
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass ComputePrimValue() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PrimValueComputeInjector mutator;
+
+    IRModule updates;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto updated = Downcast<Function>(mutator(func.value()));
+        if (!updated.same_as(base_func)) {
+          updates->Add(gvar, updated);
+        }
+      }
+    }
+
+    if (updates->functions.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+      write_ptr->Update(updates);
+      write_ptr->Update(mutator.Finalize());
+    }
+
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, "ComputePrimValue", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc
index 755c5dbab4..0912981775 100644
--- a/src/relax/transform/dataflow_inplace.cc
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -877,10 +877,12 @@ class ModuleInplaceTransformer : public ExprMutator {
     auto inline_legal_op_name = legal_op->name_hint + "_inplace";
 
     auto mod = builder_->GetContextIRModule();
-    auto legal_primfunc = Downcast<tir::PrimFunc>(mod->Lookup(legal_op));
-    auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite();
+    auto old_primfunc = Downcast<tir::PrimFunc>(mod->Lookup(legal_op));
+
+    tir::Stmt new_body = old_primfunc->body;
+
     size_t num_outs = inplace_indices.size();
-    size_t num_params = legal_primfunc->params.size();
+    size_t num_params = old_primfunc->params.size();
 
     // the replacement we must make:
     // 1. For each output var, replace its corresponding buffers with the corresponding inplace
@@ -893,42 +895,43 @@ class ModuleInplaceTransformer : public ExprMutator {
     Map<tir::Var, tir::Var> var_subst_map;
     for (size_t i = 0; i < num_outs; i++) {
       // we will substitute output i with the corresponding param indicated by inplace indices
-      auto output_var = legal_primfunc->params[num_params - num_outs + i];
-      auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()];
+      auto output_var = old_primfunc->params[num_params - num_outs + i];
+      auto inplace_var = old_primfunc->params[inplace_indices[i].IntValue()];
       var_subst_map.Set(output_var, inplace_var);
 
       // also do the same with the buffer vars
-      auto output_buffer = legal_primfunc->buffer_map.at(output_var);
-      auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var);
+      auto output_buffer = old_primfunc->buffer_map.at(output_var);
+      auto inplace_buffer = old_primfunc->buffer_map.at(inplace_var);
       var_subst_map.Set(output_buffer->data, inplace_buffer->data);
       buffer_subst_map.Set(output_buffer, inplace_buffer);
     }
 
     // apply substitutions
-    legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body, buffer_subst_map);
-    legal_primfunc_cow->body = tir::Substitute(
-        legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional<PrimExpr> {
-          if (var_subst_map.count(v)) {
-            return var_subst_map.at(v);
-          }
-          return Optional<PrimExpr>();
-        });
+    new_body = RemapBuffers(new_body, buffer_subst_map);
+    new_body = tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> Optional<PrimExpr> {
+      if (var_subst_map.count(v)) {
+        return var_subst_map.at(v);
+      }
+      return Optional<PrimExpr>();
+    });
 
     // remove the now-unused outputs from the buffer map
-    auto buffer_map = legal_primfunc->buffer_map;
+    auto new_buffer_map = old_primfunc->buffer_map;
     for (size_t i = 0; i < num_outs; i++) {
-      buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]);
+      new_buffer_map.erase(old_primfunc->params[num_params - num_outs + i]);
     }
-    legal_primfunc_cow->buffer_map = buffer_map;
 
     // now get rid of the last num_outputs arguments
     // (couldn't do earlier or else it would have thrown off the indexing)
-    legal_primfunc_cow->params = Array<tir::Var>(
-        legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs));
+    Array<tir::Var> new_params(old_primfunc->params.begin(),
+                               old_primfunc->params.begin() + (num_params - num_outs));
+
+    tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map,
+                               old_primfunc->attrs, old_primfunc->span);
 
     // note: this might be a good time to get rid of the old legalized function, but we don't do it
     // now because later ops might need the same one. Instead, we will clean up at the end
-    auto new_gv = builder_->AddFunction(legal_primfunc, inline_legal_op_name);
+    auto new_gv = builder_->AddFunction(new_primfunc, inline_legal_op_name);
 
     // update the call (change the op, update the argument, change the attrs)
     legalized_call_cow->op = call_tir_inplace_op;
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index efb2d02204..a15ee79fac 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -220,12 +220,21 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
 
 bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
                       bool permit_unknown_dtype) {
-  const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
-  if (!tt) {
+  DataType dtype;
+  int ndim;
+
+  if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+    dtype = tensor->dtype;
+    ndim = tensor->ndim;
+  } else if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
+    dtype = prim->dtype;
+    ndim = 0;
+  } else {
     return false;
   }
-  bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void());
-  bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1);
+
+  bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype && dtype.is_void());
+  bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1);
   return correct_dtype && correct_rank;
 }
 
diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc
new file mode 100644
index 0000000000..c9934c4bcf
--- /dev/null
+++ b/src/tir/analysis/is_pure_function.cc
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file is_pure_function.cc
+ * \brief PrimFunc purity analysis
+ */
+#include <tvm/ir/op.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../ir/tir_visitor_with_path.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+class PurityChecker : TIRVisitorWithPath {
+ public:
+  static bool Check(const PrimFunc& func, bool assert_on_error) {
+    PurityChecker visitor(assert_on_error);
+    visitor(func);
+    return visitor.is_pure_;
+  }
+
+ private:
+  explicit PurityChecker(bool assert_on_error) : assert_on_error_(assert_on_error) {}
+
+  void VisitStmt_(const AllocateNode* op, ObjectPath path) override {
+    internal_allocations_.insert(op->buffer_var);
+    TIRVisitorWithPath::VisitStmt_(op, path);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override {
+    TIRVisitorWithPath::VisitStmt_(op, path);
+
+    if (!internal_allocations_.count(op->buffer->data)) {
+      is_pure_ = false;
+      LOG_IF(FATAL, assert_on_error_) << "AssertionError: "
+                                      << "Pure functions must not write to 
buffers, "
+                                      << ", but function contains store to " 
<< op->buffer
+                                      << op->indices << " of value " << 
op->value;
+    }
+  }
+
+  void VisitExpr_(const CallNode* call, ObjectPath path) override {
+    TIRVisitorWithPath::VisitExpr_(call, path);
+
+    static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+    CallEffectKind effect = [&]() {
+      if (auto opt = call->op.as<Op>()) {
+        return static_cast<CallEffectKind>(op_call_effect[opt.value()]->value);
+      } else {
+        return CallEffectKind::kOpaque;
+      }
+    }();
+
+    if (effect == CallEffectKind::kUpdateState || effect == CallEffectKind::kOpaque) {
+      is_pure_ = false;
+      LOG_IF(FATAL, assert_on_error_)
+          << "AssertionError: "
+          << "Pure functions must not contain calls to impure operators, "
+          << "but " << GetRef<PrimExpr>(call) << " calls operator " << call->op
+          << ", which has side effect " << effect;
+    }
+  }
+
+  bool assert_on_error_{false};
+  bool is_pure_{true};
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> internal_allocations_;
+};
+}  // namespace
+
+bool IsPureFunction(const PrimFunc& func, bool assert_on_error) {
+  return PurityChecker::Check(func, assert_on_error);
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 5067d90838..8a3d2d6947 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -21,12 +21,52 @@
  * \file src/tir/ir/function.cc
  * \brief The function data structure.
  */
+#include <tvm/relax/struct_info.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 
 namespace tvm {
 namespace tir {
+namespace {
+relax::StructInfo InferStructInfo(const PrimFunc& prim_func) {
+  Array<relax::StructInfo> params;
+  for (const auto& param : prim_func->params) {
+    relax::StructInfo param_sinfo = [&]() -> relax::StructInfo {
+      if (auto opt_buf = prim_func->buffer_map.Get(param)) {
+        auto buf = opt_buf.value();
+        relax::ShapeExpr shape(
+            buf->shape.Map([](PrimExpr dim) { return cast(DataType::Int(64), dim); }));
+        return relax::TensorStructInfo(shape, buf->dtype);
+      }
+
+      if (auto prim_type = param->type_annotation.as<PrimTypeNode>();
+          prim_type && prim_type->dtype.is_handle()) {
+        return relax::ObjectStructInfo();
+      }
+
+      return relax::PrimStructInfo(param->dtype);
+    }();
+    params.push_back(param_sinfo);
+  }
+
+  relax::StructInfo ret = [&]() -> relax::StructInfo {
+    if (const auto* prim = prim_func->ret_type.as<PrimTypeNode>()) {
+      return relax::PrimStructInfo(prim->dtype);
+    } else if (IsVoidType(prim_func->ret_type)) {
+      return relax::TupleStructInfo(Array<relax::StructInfo>{});
+    } else {
+      return relax::ObjectStructInfo();
+    }
+  }();
+
+  bool purity = prim_func->body.defined() ? IsPureFunction(prim_func) : false;
+
+  return relax::FuncStructInfo(params, ret, purity);
+}
+}  // namespace
+
 // Get the function type of a PrimFunc
 PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
                    Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span 
span) {
@@ -42,8 +82,11 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type 
ret_type,
   n->buffer_map = std::move(buffer_map);
   n->attrs = std::move(attrs);
   n->checked_type_ = n->func_type_annotation();
+  n->struct_info_ = relax::FuncStructInfo::OpaqueFunc();
   n->span = std::move(span);
   data_ = std::move(n);
+
+  (*this)->struct_info_ = InferStructInfo(*this);
 }
 
 FuncType PrimFuncNode::func_type_annotation() const {
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 8095b3141f..924ef9a0cd 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -105,14 +105,10 @@ class PrimFuncSpecializer : public StmtExprMutator {
     Stmt body = specializer(f->body);
 
     if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
-      PrimFuncNode* f_ptr = f.CopyOnWrite();
-      f_ptr->params = std::move(params);
-      f_ptr->buffer_map = std::move(buffer_map);
-      f_ptr->body = std::move(body);
-      f_ptr->struct_info_ = NullOpt;
-      f_ptr->checked_type_ = Type(nullptr);
+      return PrimFunc(params, body, f->ret_type, buffer_map, f->attrs, f->span);
+    } else {
+      return f;
     }
-    return f;
   }
 
  private:
diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc
index 8a122f8922..28d1100f6b 100644
--- a/src/tir/transforms/renew_defs.cc
+++ b/src/tir/transforms/renew_defs.cc
@@ -76,11 +76,7 @@ class RenewDefMutator : public StmtExprMutator {
     // Visit body
     Stmt body = generator(func->body);
     // Recreate function
-    auto n = make_object<PrimFuncNode>(*func.get());
-    n->params = std::move(params);
-    n->buffer_map = std::move(buffer_map);
-    n->body = std::move(body);
-    return PrimFunc(n);
+    return PrimFunc(params, body, func->ret_type, buffer_map, func->attrs, func->span);
   }
 
  private:
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index b76b95646a..7deddfd28e 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -22,6 +22,7 @@ from tvm import tir
 from tvm.script import relax as R
 from tvm.script import ir as I
 from tvm.script import tir as T
+from tvm.script import ir as I
 
 m = tir.Var("m", "int64")
 n = tir.Var("n", "int64")
@@ -656,5 +657,50 @@ def test_well_formed_function_referencing_global_var():
     assert rx.analysis.well_formed(Module["subroutine"])
 
 
+def test_pass_dltensor_arg_to_tir():
+    """Relax may pass R.Tensor as DLTensor
+
+    In TIR, a `DLTensor*` argument with unknown shape and dtype is
+    represented as a `tir.Var` with
+    `tvm::PrimType(DataType::Handle())`, and with no entry in the
+    `PrimFuncNode::buffer_map`.  In Relax, this is represented as
+    `R.Tensor`.  Calls from Relax to TIR that pass a tensor of unknown
+    rank/shape are well-formed.
+
+    In the test case below, a TIR function accepts an arbitrary
+    `R.Tensor`, and returns a boolean value based on inspection of the
+    runtime datatype.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor) -> R.Prim("bool"):
+            return Module.is_bfloat16_dtype(A)
+
+        @T.prim_func(private=True)
+        def is_bfloat16_dtype(tensor: T.handle) -> T.bool:
+            T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True})
+
+            # From #include <tvm/tir/builtin.h>
+            kArrTypeCode = T.meta_var(5)
+            kArrTypeBits = T.meta_var(6)
+            kArrTypeLanes = T.meta_var(7)
+
+            # From #include <dlpack/dlpack.h>
+            kDLBfloat = T.meta_var(4)
+
+            type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode, dtype="uint8")
+            type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits, dtype="uint8")
+            type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes, dtype="uint16")
+
+            is_bfloat16: T.bool = (
+                (type_code == kDLBfloat) and (type_bits == 16) and (type_lanes == 1)
+            )
+            return is_bfloat16
+
+    assert rx.analysis.well_formed(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 31eb4b26be..fccf3a5f8a 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -452,6 +452,90 @@ def test_return_match_check():
     assert_structural_equal(after, expected)
 
 
+def test_return_match_check_with_new_expr():
+    """Like test_return_match_check, but requires a computation
+
+    When the struct info of the returned value is not the same as
+    ret_struct_info, a runtime match check is required.  This match
+    check may require a symbolic expression to be computed.
+    """
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], 
"float32"):
+            R.func_attr({"relax.force_pure": True})
+            out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object)
+            return out
+
+    # slot assignment:
+    sindex = {
+        "n": 0,
+        "n * n": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], 
"float32"):
+            R.func_attr({"relax.force_pure": True})
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", 
sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+
+            _ = Expected.shape_func(shape_heap)
+
+            out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object)
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                out,
+                1,
+                R.dtype("float32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                out,
+                shape_heap,
+                1,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n * n"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return out
+
+        @T.prim_func(private=True)
+        def shape_func(H: T.Buffer(T.int64(2), "int64")):
+            # generated compute function
+            T.func_attr({"tir.is_host_func": 1})
+            H[T.int64(sindex["n * n"])] = H[T.int64(sindex["n"])] * 
H[T.int64(sindex["n"])]
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
 def test_symbolic_shape_multiple_function():
     MS = MatchShapeCode
     MK = MakeShapeCode
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index a278b09167..41618a32cb 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -19,6 +19,8 @@ import sys
 import tempfile
 
 import numpy as np
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax
@@ -35,13 +37,18 @@ class InputModule:
         return y, y_sorted
 
 
-def run_cpu(mod, func_name, *input):
+def run_cpu(mod, func_name, *args):
+    if isinstance(mod, relax.Function):
+        func = mod
+        args = [func_name, *args]
+        func_name = func.attrs["global_symbol"]
+        mod = tvm.IRModule.from_expr(func)
+
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
-    vm.set_input(func_name, *input)
-    vm.invoke_stateful(func_name)
-    return vm.get_outputs(func_name)
+
+    return vm[func_name](*args)
 
 
 def test_unique():
@@ -88,67 +95,108 @@ def test_print():
         sys.stdout = stdout
 
 
-@tvm.script.ir_module
-class AssertOpTest:
+def test_assert_passes():
     @R.function(pure=False)
-    def passes(x: R.Tensor((), "int32")):
-        p1 = R.assert_op(relax.const(True))
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(True))
         return x
 
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_passes_with_format_args():
     @R.function(pure=False)
-    def pass_with_args(x: R.Tensor((), "int32")):
-        p1 = R.assert_op(relax.const(True), x, format="You won't see me")
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(True), x, format="You won't see me")
         return x
 
+    run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails():
+    @R.function(pure=False)
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(False))
+        return x
+
+    with pytest.raises(AssertionError, match="Assertion Failed"):
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails_with_message():
     @R.function(pure=False)
-    def simple_fail(x: R.Tensor((), "int32")):
-        p1 = R.assert_op(relax.const(False))
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(False), format="I failed...")
         return x
 
+    with pytest.raises(AssertionError, match="I failed..."):
+        run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails_with_args():
     @R.function(pure=False)
-    def fail_with_message(x: R.Tensor((), "int32")):
-        p1 = R.assert_op(relax.const(False), format="I failed...")
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(False), [x, x])
         return x
 
+    with pytest.raises(AssertionError, match="5, 5"):
+        run_cpu(func, tvm.nd.array(np.array(5).astype("int32")))
+
+
+def test_assert_fails_with_formatted_args():
     @R.function(pure=False)
-    def fail_with_args(x: R.Tensor((), "int32")):
-        # no format
-        p1 = R.assert_op(relax.const(False), [x, x])
+    def func(x: R.Tensor((), "int32")):
+        _ = R.assert_op(relax.const(False), x, format="Number: {}")
         return x
 
+    with pytest.raises(AssertionError, match="Number: 6"):
+        run_cpu(func, tvm.nd.array(np.array(6).astype("int32")))
+
+
+def test_assert_on_argument_passes():
     @R.function(pure=False)
-    def fail_with_formatted_message(x: R.Tensor((), "int32")):
-        p1 = R.assert_op(relax.const(False), x, format="Number: {}")
+    def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
+        _ = R.assert_op(condition)
         return x
 
+    condition = tvm.nd.array(np.array(True))
+    x = tvm.nd.array(np.array(5).astype("int32"))
+    run_cpu(func, condition, x)
 
-def test_assert_op():
-    def check_assertion_error(func_name, func_arg, expected_message):
-        passed = False
-        try:
-            run_cpu(AssertOpTest, func_name, func_arg)
-            passed = True
-        except TVMError as e:
-            # TVM will print out a TVMError that will contain the
-            # generated error at the bottom of a stack trace
-            assert "AssertionError" in e.args[0]
-            assert expected_message in e.args[0]
-        except AssertionError:
-            return
-        assert not passed
-
-    run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32")))
-    run_cpu(AssertOpTest, "pass_with_args", 
tvm.nd.array(np.array(2).astype("int32")))
-    check_assertion_error(
-        "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion 
Failed"
-    )
-    check_assertion_error(
-        "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I 
failed..."
-    )
-    check_assertion_error("fail_with_args", 
tvm.nd.array(np.array(5).astype("int32")), "5, 5")
-    check_assertion_error(
-        "fail_with_formatted_message", 
tvm.nd.array(np.array(6).astype("int32")), "Number: 6"
-    )
+
+def test_assert_on_argument_fails():
+    @R.function(pure=False)
+    def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
+        _ = R.assert_op(condition)
+        return x
+
+    condition = tvm.nd.array(np.array(False))
+    x = tvm.nd.array(np.array(5).astype("int32"))
+    with pytest.raises(AssertionError):
+        run_cpu(func, condition, x)
+
+
+def test_assert_on_symbolic_var_passes():
+    @R.function(pure=False)
+    def func(x: R.Tensor(["N"], "int32")):
+        N = T.int64()
+        _ = R.assert_op(R.prim_value(N % 8 == 0))
+        return x
+
+    x = tvm.nd.array(np.arange(8, dtype="int32"))
+    run_cpu(func, x)
+
+
+def test_assert_on_symbolic_var_fails():
+    @R.function(pure=False)
+    def func(x: R.Tensor(["N"], "int32")):
+        N = T.int64()
+        _ = R.assert_op(R.prim_value(N % 8 == 0))
+        return x
+
+    x = tvm.nd.array(np.arange(10, dtype="int32"))
+    with pytest.raises(AssertionError):
+        run_cpu(func, x)
 
 
 @tvm.script.ir_module
@@ -370,5 +418,60 @@ def test_op_to_vdevice():
     assert (copy_found.numpy() == arr).all()
 
 
+def test_scalar_tensor_as_branch_condition():
+    """The condition of a branch may be a scalar tensor"""
+
+    @R.function
+    def func(condition: R.Tensor((), "bool")):
+        if condition:
+            out = R.prim_value(5)
+        else:
+            out = R.prim_value(10)
+        return out
+
+    res = run_cpu(func, tvm.nd.array(np.array(True)))
+    assert res == 5
+
+    res = run_cpu(func, tvm.nd.array(np.array(False)))
+    assert res == 10
+
+
+def test_prim_value_as_branch_condition():
+    """The condition may be a PrimValue"""
+
+    @R.function
+    def func(condition: R.Prim("bool")):
+        if condition:
+            out = R.prim_value(5)
+        else:
+            out = R.prim_value(10)
+        return out
+
+    res = run_cpu(func, True)
+    assert res == 5
+
+    res = run_cpu(func, False)
+    assert res == 10
+
+
+def test_computed_prim_value_as_branch_condition():
+    """The R.Prim condition may be computed within the function"""
+
+    @R.function
+    def func(x: R.Tensor(["N"], "int64")):
+        N = T.int64()
+        if R.prim_value(N % 16 == 0):
+            out = R.prim_value(5)
+        else:
+            out = R.prim_value(10)
+        return out
+
+    res = run_cpu(func, tvm.nd.array(np.arange(16)))
+    assert res == 5
+
+    res = run_cpu(func, tvm.nd.array(np.arange(20)))
+    assert res == 10
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 9ab2ffc605..7fbf9a2da1 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -343,14 +343,18 @@ def test_call_tir_inplace_multiple_args():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func
-        def copy(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")):
+        def copy(
+            A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: 
T.Buffer((2, 3), "int32")
+        ):
+            # copies the contents of C into A and B
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_zeros"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(B[ax0, ax1])
-                    T.writes(A[ax0, ax1])
-                    A[ax0, ax1] = B[ax0, ax1]
+                    T.reads(C[ax0, ax1])
+                    T.writes(A[ax0, ax1], B[ax0, ax1])
+                    A[ax0, ax1] = C[ax0, ax1]
+                    B[ax0, ax1] = C[ax0, ax1]
 
         @R.function
         def foo(
diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py
new file mode 100644
index 0000000000..9fee35414d
--- /dev/null
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.relax.transform.ComputePrimValue()
+
+
+class TestPrimValueInAssertCondition(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function(pure=False)
+        def main(A: R.Tensor(["N"])):
+            N = T.int64()
+            _ = R.assert_op(N % 16 == 0)
+            return A
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def main(A: R.Tensor(["N"])):
+            N = T.int64()
+            condition: R.Prim("bool") = 
Expected.compute_symbolic_expr(R.prim_value(N))
+            _ = R.assert_op(condition)
+            return A
+
+        @T.prim_func(private=True)
+        def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.ret(N % 16 == 0)
+
+
+class TestPrimValueInBranchCondition(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function(pure=False)
+        def main(A: R.Tensor(["N"])):
+            N = T.int64()
+            if R.prim_value(N % 16 == 0):
+                out = R.call_packed("fast_vectorized_impl", A, 
sinfo_args=[A.struct_info])
+            else:
+                out = R.call_packed("slow_non_vectorized_impl", A, 
sinfo_args=[A.struct_info])
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def main(A: R.Tensor(["N"])):
+            N = T.int64()
+            condition: R.Prim("bool") = 
Expected.compute_symbolic_expr(R.prim_value(N))
+            if condition:
+                out = R.call_packed("fast_vectorized_impl", A, 
sinfo_args=[A.struct_info])
+            else:
+                out = R.call_packed("slow_non_vectorized_impl", A, 
sinfo_args=[A.struct_info])
+            return out
+
+        @T.prim_func(private=True)
+        def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.ret(N % 16 == 0)
+
+
+class TestPrimValueInPureFunction(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> 
R.Prim(value="N*M"):
+            N = T.int64()
+            M = T.int64()
+            out = R.prim_value(N * M)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> 
R.Prim(value="N*M"):
+            N = T.int64()
+            M = T.int64()
+            out = Expected.compute_symbolic_expr(R.prim_value(N), R.prim_value(M))
+            return out
+
+        @T.prim_func(private=True)
+        def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
+            T.ret(N * M)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 2221cb89eb..c8db26c81b 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1261,6 +1261,149 @@ def test_if_branch_var_scope():
             return w
 
 
+def test_scalar_tensor_as_branch_condition():
+    """Branch condition can be 0-d tensor"""
+
+    @R.function
+    def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
+        if cond:
+            out = R.add(x, x)
+        else:
+            out = R.multiply(x, x)
+        return out
+
+    if_else = func.body.blocks[0].bindings[0].value
+    assert isinstance(if_else.cond, relax.Var)
+    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([], "bool"))
+
+
+def test_prim_value_as_branch_condition():
+    """In addition to scalar tensor, can use R.Prim condition"""
+
+    @R.function
+    def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
+        if cond:
+            out = R.add(x, x)
+        else:
+            out = R.multiply(x, x)
+        return out
+
+    if_else = func.body.blocks[0].bindings[0].value
+    assert isinstance(if_else.cond, relax.Var)
+    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool"))
+
+
+def test_computed_prim_value_as_branch_condition():
+    """The R.Prim condition may be computed within the function"""
+
+    @R.function
+    def func(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        if R.prim_value(N % 16 == 0):
+            out = R.call_pure_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        else:
+            out = R.call_pure_packed("slow_non_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    N = func.params[0].struct_info.shape[0]
+    if_else = func.body.blocks[0].bindings[0].value
+    assert isinstance(if_else.cond, relax.PrimValue)
+    tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value)
+    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N % 16 == 0))
+
+
+def test_tir_expr_as_branch_condition():
+    """Syntactic sugar, wrap PrimExpr as PrimValue"""
+
+    @R.function(private=True)
+    def sugared(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        if N % 16 == 0:
+            out = R.call_pure_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        else:
+            out = R.call_pure_packed("slow_non_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    @R.function(private=True)
+    def unsugared(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        if R.prim_value(N % 16 == 0):
+            out = R.call_pure_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        else:
+            out = R.call_pure_packed("slow_non_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    tvm.ir.assert_structural_equal(unsugared, sugared)
+
+
+def test_scalar_tensor_as_assert_condition():
+    """Branch condition can be 0-d tensor"""
+
+    @R.function(pure=False)
+    def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
+        _ = R.assert_op(cond)
+        out = R.add(x, x)
+        return out
+
+    assert_op = func.body.blocks[0].bindings[0].value
+    condition = assert_op.args[0]
+    assert isinstance(condition, relax.Var)
+    tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool"))
+
+
+def test_prim_value_as_assert_condition():
+    """In addition to scalar tensor, can use R.Prim condition"""
+
+    @R.function(pure=False)
+    def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
+        _ = R.assert_op(cond)
+        out = R.add(x, x)
+        return out
+
+    assert_op = func.body.blocks[0].bindings[0].value
+    condition = assert_op.args[0]
+    assert isinstance(condition, relax.Var)
+    tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool"))
+
+
+def test_computed_prim_value_as_assert_condition():
+    """The R.Prim condition may be computed within the function"""
+
+    @R.function(pure=False)
+    def func(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        _ = R.assert_op(R.prim_value(N % 16 == 0))
+        out = R.call_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    N = func.params[0].struct_info.shape[0]
+    assert_op = func.body.blocks[0].bindings[0].value
+    condition = assert_op.args[0]
+    assert isinstance(condition, relax.PrimValue)
+    tvm.ir.assert_structural_equal(N % 16 == 0, condition.value)
+    tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16 == 0))
+
+
+def test_tir_expr_as_assert_condition():
+    """Syntactic sugar, wrap PrimExpr as PrimValue"""
+
+    @R.function(pure=False, private=True)
+    def sugared(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        _ = R.assert_op(N % 16 == 0)
+        out = R.call_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    @R.function(pure=False, private=True)
+    def unsugared(x: R.Tensor(["N"], "float32")):
+        N = T.int64()
+        _ = R.assert_op(R.prim_value(N % 16 == 0))
+        out = R.call_packed("fast_vectorized_impl", x, 
sinfo_args=[x.struct_info])
+        return out
+
+    tvm.ir.assert_structural_equal(unsugared, sugared)
+
+
 def test_erase_to_well_defined_removes_internal_vars():
     @R.function
     def foo(x: R.Tensor):
@@ -1664,9 +1807,9 @@ def test_context_aware_parsing():
     class Module:
         @T.prim_func
         def add(
-            X: T.Buffer(T.int64(8), "float32"),
+            X: T.Buffer([T.int64(2), T.int64(4)], "float32"),
             Y: T.Buffer((), "float32"),
-            Z: T.Buffer(T.int64(8), "float32"),
+            Z: T.Buffer([T.int64(2), T.int64(4)], "float32"),
         ):
             T.evaluate(0)
 
diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py
index 21e192955b..9a4817f5fd 100644
--- a/tests/python/relax/test_vm_codegen_tir.py
+++ b/tests/python/relax/test_vm_codegen_tir.py
@@ -72,7 +72,7 @@ def test_tir_call():
             H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
 
         @R.function(pure=False)
-        def foo(x: R.Tensor):
+        def foo(x: R.Tensor([4], "int64")):
             R.func_attr({"global_symbol": "foo"})
             _ = Before.shape_func(x)
             return x
diff --git a/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py
new file mode 100644
index 0000000000..6555ae3f77
--- /dev/null
+++ b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm.testing
+from tvm.script import tir as T
+
+from tvm.tir.analysis import is_pure_function, assert_pure_function
+
+
+class CheckPureFunction:
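+    """Mix-in; the subclass provides `func`, which should be pure."""
+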
+    def test_check_purity(self):
+        assert is_pure_function(self.func)
+
+    def test_assert_purity(self):
+        assert_pure_function(self.func)
+
+
+class CheckImpureFunction:
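+    """Mix-in; the subclass provides `func`, which should be impure."""
+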
+    def test_check_purity(self):
+        assert not is_pure_function(self.func)
+
+    def test_assert_purity(self):
+        with pytest.raises(AssertionError):
+            assert_pure_function(self.func)
+
+
+class TestNoOp(CheckPureFunction):
+    @T.prim_func
+    def func():
+        pass
+
+
+class TestReturnValue(CheckPureFunction):
+    @T.prim_func
+    def func() -> T.int32:
+        T.ret(42)
+
+
+class TestComputeValueAndReturn(CheckPureFunction):
+    @T.prim_func
+    def func(N: T.int32, M: T.int32) -> T.int32:
+        T.ret(N * M)
+
+
+class TestReadBufferArgument(CheckPureFunction):
+    @T.prim_func
+    def func(A: T.Buffer(16, "float32")) -> T.float32:
+        T.ret(A[0])
+
+
+class TestWriteToBufferArgument(CheckImpureFunction):
+    @T.prim_func
+    def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
+        for i in range(16):
+            B[i] = A[i]
+
+
+class TestWriteToInternalAllocation(CheckPureFunction):
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "float32")) -> T.float32:
+        Sum = T.decl_buffer([], "float32")
+        Sum[()] = 0.0
+        for i, j in T.grid(16, 16):
+            Sum[()] = Sum[()] + A[i, j]
+
+        T.ret(Sum[()])
+
+
+class TestCallPureBuiltin(CheckPureFunction):
+    @T.prim_func
+    def func(x: T.float32) -> T.float32:
+        T.ret(T.cos(x))
+
+
+class TestCallPureExtern(CheckPureFunction):
+    @T.prim_func
+    def func():
+        T.call_pure_extern("some_pure_extern_func_name", dtype="void")
+
+
+class TestCallImpureExtern(CheckImpureFunction):
+    @T.prim_func
+    def func():
+        T.call_extern("some_impure_extern_func_name", dtype="void")
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py
index 0422887233..cead775e97 100644
--- a/tests/python/tir-base/test_tir_specialize.py
+++ b/tests/python/tir-base/test_tir_specialize.py
@@ -330,12 +330,11 @@ def test_specialize_buffer_var_to_expr():
     tvm.ir.assert_structural_equal(expected, after)
 
 
-def test_specialization_removes_struct_info():
-    """Reset struct info in specialization
+def test_specialization_updates_struct_info():
+    """Update struct info in specialization
 
-    While a PrimFunc usually doesn't have a `relax.StructInfo`, the
-    field can be populated in some edge cases.  If that PrimFunc is
-    specialized, the struct info should be reset.
+    A PrimFunc may have a `relax.StructInfo`.  If that PrimFunc is
+    specialized, the struct info should be updated.
     """
 
     @T.prim_func(private=True)
@@ -346,24 +345,20 @@ def test_specialization_removes_struct_info():
     def expected() -> T.int32:
         T.ret(50)
 
-    sinfo = tvm.relax.FuncStructInfo(
+    sinfo_before = tvm.relax.FuncStructInfo(
         [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
     )
-    tvm.relax.expr._update_struct_info(before, sinfo)
+    tvm.ir.assert_structural_equal(before.struct_info, sinfo_before)
+
+    sinfo_expected = tvm.relax.FuncStructInfo([], tvm.relax.PrimStructInfo("int32"))
+    tvm.ir.assert_structural_equal(expected.struct_info, sinfo_expected)
 
     n = before.params[0]
     param_map = {n: 5}
     after = before.specialize(param_map)
 
-    tvm.ir.assert_structural_equal(expected, after)
-    assert before.struct_info is not None
-
-    # PrimFuncs do not expose the `struct_info_` field.  Checking the
-    # `struct_info` field when it isn't set raises an exception.  This
-    # is the desired behavior, since the struct info before
-    # specialization is no longer valid.
-    with pytest.raises(tvm.TVMError):
-        after.struct_info
+    tvm.ir.assert_structural_equal(after, expected)
+    tvm.ir.assert_structural_equal(after.struct_info, sinfo_expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 074603681f..465ffa5cb6 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -340,5 +340,114 @@ def test_thread_binding_dtype():
     assert loop_j.thread_binding.var.dtype == "int32"
 
 
+def test_inferred_sinfo_with_prim_args():
+    """A PrimFunc may have inferred StructInfo"""
+
+    @T.prim_func
+    def func(M: T.int32, N: T.int32) -> T.int32:
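+        # Scalar parameters and the T.ret value map to R.Prim struct info.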
+        T.ret(M * N)
+
+    expected = tvm.relax.FuncStructInfo(
+        [
+            tvm.relax.PrimStructInfo("int32"),
+            tvm.relax.PrimStructInfo("int32"),
+        ],
+        tvm.relax.PrimStructInfo("int32"),
+        purity=True,
+    )
+    tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_buffer_args():
+    """PrimFunc buffer arguments are inferred as R.Tensor"""
+
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "float32"), B: T.Buffer([256], "int32")) -> 
T.float32:
+        T.ret(T.float32(42.0))
+
+    expected = tvm.relax.FuncStructInfo(
+        [
+            tvm.relax.TensorStructInfo([16, 16], "float32"),
+            tvm.relax.TensorStructInfo([256], "int32"),
+        ],
+        tvm.relax.PrimStructInfo("float32"),
+        purity=True,
+    )
+    tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_internal_allocation():
+    """A pure function may still write to internal allocations.
+
+    Writing to an internal allocation is not an externally visible
+    effect, and so does not impact the purity of the function.
+    """
+
+    @T.prim_func
+    def func(A: T.Buffer([16, 16], "float32")) -> T.float32:
+        Sum = T.decl_buffer([], "float32")
+        Sum[()] = 0.0
+        for i, j in T.grid(16, 16):
+            Sum[()] = Sum[()] + A[i, j]
+
+        T.ret(Sum[()])
+
+    expected = tvm.relax.FuncStructInfo(
+        [
+            tvm.relax.TensorStructInfo([16, 16], "float32"),
+        ],
+        tvm.relax.PrimStructInfo("float32"),
+        purity=True,
+    )
+    tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_output_buffer():
+    """A pure function may not write to an argument buffer
+
+    If an argument buffer is written to, the function must be impure.
+    """
+
+    @T.prim_func
+    def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
+        for i in range(16):
+            B[i] = A[i]
+
+    expected = tvm.relax.FuncStructInfo(
+        [
+            tvm.relax.TensorStructInfo([16], "float32"),
+            tvm.relax.TensorStructInfo([16], "float32"),
+        ],
+        tvm.relax.TupleStructInfo([]),
+        purity=False,
+    )
+    tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_dynamic_buffer():
+    """The inferred StructInfo may contain dynamic shapes"""
+
+    @T.prim_func
+    def func(a_handle: T.handle, b_handle: T.handle):
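+        # M and N are symbolic; T.match_buffer binds them from the buffer shapes.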
+        M = T.int64()
+        N = T.int64()
+        A = T.match_buffer(a_handle, [M, N], "float32")
+        B = T.match_buffer(b_handle, [M * N], "float32")
+        for i, j in T.grid(M, N):
+            B[i * N + j] = A[i, j]
+
+    M = tvm.tir.Var("M", "int64")
+    N = tvm.tir.Var("N", "int64")
+    expected = tvm.relax.FuncStructInfo(
+        [
+            tvm.relax.TensorStructInfo([M, N], "float32"),
+            tvm.relax.TensorStructInfo([M * N], "float32"),
+        ],
+        tvm.relax.TupleStructInfo([]),
+        purity=False,
+    )
+    tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
