This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 622bd150dd [Relax] Handle binary operations between Tensor and 
PrimValue (#16827)
622bd150dd is described below

commit 622bd150dd331780eb41a1c67c65aae802eb9b20
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Apr 18 16:41:59 2024 -0500

    [Relax] Handle binary operations between Tensor and PrimValue (#16827)
    
    * [Relax] Handle binary operations between Tensor and PrimValue
    
    Prior to this commit, binary operations were only defined between two
    tensors.  This commit allows binary operations to apply between a
    tensor and a `relax::PrimValue`.
    
    When inferring the output `StructInfo`, binary operations with a
    `PrimValue` produce the same output as using a 0-d tensor.  When
    legalizing operations containing a `PrimValue`, they are lowered to
    primitive TIR arguments.
    
    * Fix unit tests
    
    * Restore ICHECK for scalar TIR variable
    
    * Fix a few more unit tests
    
    * Remove handling of ObjectStructInfo
    
    * Undo commenting-out of test cases
    
    * Update for improved error messages
    
    * Fix failing unit tests
    
    * Fix unit test
---
 python/tvm/relax/utils.py                          | 130 +++--
 src/relax/op/op_common.h                           | 103 +++-
 src/relax/op/tensor/binary.cc                      | 112 ++++-
 src/script/printer/relax/tir.cc                    |   7 +-
 src/te/operation/create_primfunc.cc                |  15 +-
 tests/python/relax/test_op_binary.py               | 106 +++-
 tests/python/relax/test_op_nn_convolution.py       |   8 +-
 tests/python/relax/test_op_search.py               |   4 +-
 .../relax/test_transform_legalize_ops_binary.py    | 534 ++++++++++++++++++++-
 9 files changed, 887 insertions(+), 132 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index a58b65477c..48beeed8da 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -14,13 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 # pylint: disable=invalid-name,too-many-locals
+
 """Utility functions for Relax"""
+
 import functools
 import inspect
+import itertools
+import string
+
 from typing import Tuple as typing_Tuple
 from typing import Any, Callable, List, Dict, Optional, TypeVar
 
+import tvm
 from .. import tir
 from ..tir import PrimExpr
 from ..runtime import String, convert_to_object
@@ -302,9 +309,23 @@ def gen_call_tir_inputs(
         out_sinfo, and tir_vars.
     """
 
-    def _convert_te_arg(
-        te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
-    ) -> typing_Tuple[Any, List[te_Tensor]]:
+    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
+
+    call_tir_args = []
+    create_primfunc_args = []
+    # extra list of tir expression arguments
+    # that are not covered by Tensor
+    extra_tir_args_list = []
+
+    def _copy_undefined_var(expr: tir.PrimExpr):
+        def _visit_expr(e: tir.PrimExpr):
+            if isinstance(e, tir.Var) and e not in tir_var_map:
+                new_var = tir.Var(e.name, e.dtype)
+                tir_var_map[e] = new_var
+
+        tir.stmt_functor.post_order_visit(expr, _visit_expr)
+
+    def _convert_te_arg(te_args: Any) -> Any:
         """Helper function used to convert Relax expressions to TE tensor.
 
         In the common case, the type of te_args is a Relax expression and is 
converted
@@ -335,23 +356,8 @@ def gen_call_tir_inputs(
             A tuple of the converted te_args, and a list of te tensors for 
each converted
             Relax expression
         """
-        te_args_list = []
-        # extra list of tir expression arguments
-        # that are not covered by Tensor
-        extra_tir_args_list = []
-
-        def _copy_undefined_var(expr: tir.PrimExpr):
-            def _visit_expr(e: tir.PrimExpr):
-                if isinstance(e, tir.Var) and e not in tir_var_map:
-                    new_var = tir.Var(e.name, e.dtype)
-                    tir_var_map[e] = new_var
-
-            tir.stmt_functor.post_order_visit(expr, _visit_expr)
-
-        n_tensor = 0
 
         def _convert_te_arg_helper(arg):
-            nonlocal n_tensor
             if isinstance(arg, Expr):  # type: ignore
                 if isinstance(arg.struct_info, TensorStructInfo):
                     assert isinstance(
@@ -360,21 +366,46 @@ def gen_call_tir_inputs(
                     for shape_value in arg.struct_info.shape.values:
                         _copy_undefined_var(shape_value)
 
-                    name = chr(ord("A") + n_tensor) if n_tensor < 26 else 
f"input{n_tensor}"
-                    arg = te_tensor(arg, tir_var_map, name)
-                    n_tensor += 1
-                    te_args_list.append(arg)
-                    return arg
+                    n_args = len(create_primfunc_args)
+                    if isinstance(arg, tvm.relax.Var):
+                        name = arg.name_hint
+                    elif n_args < len(string.ascii_uppercase):
+                        name = string.ascii_uppercase[n_args]
+                    else:
+                        name = f"tensor_input_{n_args}"
+
+                    te_arg = te_tensor(arg, tir_var_map, name)
+
+                    call_tir_args.append(arg)
+                    create_primfunc_args.append(te_arg)
+
+                    return te_arg
+
                 if isinstance(arg.struct_info, ShapeStructInfo):
                     assert isinstance(
                         arg, ShapeExpr
                     ), "For Expr having ShapeStructInfo, emit_te now only 
supports ShapeExpr"
                     return [_convert_te_arg_helper(val) for val in arg.values]
-                if (
-                    isinstance(arg.struct_info, PrimStructInfo)
-                    and arg.struct_info.value is not None
-                ):
-                    return _convert_te_arg_helper(arg.struct_info.value)
+
+                if isinstance(arg.struct_info, PrimStructInfo):
+                    if arg.struct_info.value is None:
+                        n_args = len(create_primfunc_args)
+                        if isinstance(arg, tvm.relax.Var):
+                            name = arg.name_hint
+                        elif n_args < len(string.ascii_lowercase):
+                            name = string.ascii_lowercase[n_args]
+                        else:
+                            name = f"scalar_input_{n_args}"
+
+                        tir_param = tir.Var(name, arg.struct_info.dtype)
+
+                        call_tir_args.append(arg)
+                        create_primfunc_args.append(tir_param)
+
+                        return tir_param
+                    else:
+                        return _convert_te_arg_helper(arg.struct_info.value)
+
             elif isinstance(arg, (list, Array)):
                 return [_convert_te_arg_helper(x) for x in arg]
             elif isinstance(arg, tuple):
@@ -395,28 +426,36 @@ def gen_call_tir_inputs(
             raise TypeError("not supported type in emit_te: 
{}".format(type(arg)))
 
         new_arg = _convert_te_arg_helper(te_args)
-        return new_arg, te_args_list, extra_tir_args_list
+        return new_arg
 
     def _get_unbound_tir_vars(
         args: List[te_Tensor], extra_tir_args: List[PrimExpr]
     ) -> List[tir.Var]:
         """get unbound TIR vars (i.e TIR vars used in the shape but is not
         itself a dimension of a shape)"""
+
         bound_vars = set()
         used_vars = set()
 
+        def _populate_bound_vars(expr):
+            if isinstance(expr, te_Tensor):
+                for dim in expr.shape:
+                    _populate_bound_vars(dim)
+            elif isinstance(expr, tir.Var):
+                bound_vars.add(expr)
+
         def _populate_used_vars(expr):
-            if isinstance(expr, tir.Var):
-                used_vars.add(expr)
+            if isinstance(expr, te_Tensor):
+                for dim in expr.shape:
+                    _populate_used_vars(dim)
+            elif isinstance(expr, tir.PrimExpr):
+                used_vars.update(tir.analysis.undefined_vars(expr))
 
-        for val in extra_tir_args:
-            tir.stmt_functor.post_order_visit(val, _populate_used_vars)
+        for arg in itertools.chain(args, extra_tir_args):
+            _populate_used_vars(arg)
 
-        for x in args:
-            for s in x.shape:
-                tir.stmt_functor.post_order_visit(s, _populate_used_vars)
-                if isinstance(s, tir.Var):
-                    bound_vars.add(s)
+        for arg in args:
+            _populate_bound_vars(arg)
 
         diff = used_vars - bound_vars
         return list(diff)
@@ -448,21 +487,18 @@ def gen_call_tir_inputs(
 
     primfunc_attrs = kwargs.pop("primfunc_attrs", None)
 
-    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
-    new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
-    new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, 
tir_var_map)
-
-    te_args = te_arg_list + te_kwarg_list
+    te_args = _convert_te_arg(args)
+    te_kwargs = _convert_te_arg(kwargs)
 
-    te_out = func(*new_args, **new_kwargs)
+    te_out = func(*te_args, **te_kwargs)
     assert isinstance(te_out, te_Tensor) or (
         isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, 
te_Tensor) for t in te_out)
     ), "only support te.tensor or tuple/list/Array of te.tensor as function 
output"
 
     outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
-    unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + 
tir_kwarg_list)
+    unbound_tir_vars = _get_unbound_tir_vars([*create_primfunc_args, *outs], 
extra_tir_args_list)
 
-    inputs = [*te_args] + outs + unbound_tir_vars
+    inputs = [*create_primfunc_args] + outs + unbound_tir_vars
     tir_func = create_prim_func(inputs, "int64")
 
     if primfunc_attrs:
@@ -470,8 +506,6 @@ def gen_call_tir_inputs(
 
     tir_func = tir_func.without_attr("global_symbol")
 
-    call_tir_args = [x.op.value for x in te_args]
-
     # Invert the TIR variable mapping, to convert the output shape back
     # with old set of variables.
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index f5eed7af06..5e19edb47c 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -31,6 +31,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/tir/data_layout.h>
 
+#include <optional>
 #include <tuple>
 #include <utility>
 #include <vector>
@@ -239,52 +240,112 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
                                         const Map<String, Array<String>>& 
desired_layouts,
                                         const VarLayoutMap& var_layout_map);
 
+/*!
+ * \brief Get the element dtype from StructInfo
+ *
+ * \param sinfo The StructInfo to inspect
+ *
+ * \return The element dtype if `sinfo` is a PrimStructInfo or a
+ *     TensorStructInfo, otherwise std::nullopt.  This function does
+ *     not throw; callers that require a dtype must check the result
+ *     and report a diagnostic themselves.
+ */
+inline std::optional<DataType> GetElementDType(const StructInfo& sinfo) {
+  if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
+    return prim->dtype;
+  } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+    return tensor->dtype;
+  } else {
+    // Only PrimStructInfo and TensorStructInfo carry an element dtype.
+    return std::nullopt;
+  }
+}
+
 /*!
  * \brief Infer the output datatype for binary arithmetic operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
  * \return The inferred output dtype.
  * \throw Throw exception if the dtype of two input TensorStructInfo don’t 
match
  */
 inline DataType InferBinaryArithOpOutDtype(const Call& call, const 
BlockBuilder& ctx,
-                                           const TensorStructInfo& x1_sinfo,
-                                           const TensorStructInfo& x2_sinfo) {
-  if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) {
+                                           const StructInfo& lhs_sinfo,
+                                           const StructInfo& rhs_sinfo) {
+  auto opt_lhs_dtype = GetElementDType(lhs_sinfo);
+  if (!opt_lhs_dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "TypeError: "
+                     << "Binary operators must have the same datatype for both 
operands.  "
+                     << "However, " << call << " has argument " << 
call->args[0]
+                     << " on the LHS, with struct info " << lhs_sinfo << ".   
This is of type "
+                     << lhs_sinfo->GetTypeKey() << ", which does not have a 
datatype.");
+  }
+  auto lhs_dtype = opt_lhs_dtype.value();
+
+  auto opt_rhs_dtype = GetElementDType(rhs_sinfo);
+  if (!opt_rhs_dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "TypeError: "
+                     << "Binary operators must have the same datatype for both 
operands.  "
+                     << "However, " << call << " has argument " << 
call->args[1]
+                     << " on the RHS, with struct info " << rhs_sinfo << ".   
This is of type "
+                     << rhs_sinfo->GetTypeKey() << ", which does not have a 
datatype.");
+  }
+  auto rhs_dtype = opt_rhs_dtype.value();
+
+  if (lhs_dtype.is_void() || rhs_dtype.is_void()) {
     return DataType::Void();
-  } else if (x1_sinfo->dtype != x2_sinfo->dtype) {
+  } else if (lhs_dtype != rhs_dtype) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "Data types " << x1_sinfo->dtype << " and " << 
x2_sinfo->dtype
-                     << " must be equal for binary operators");
+                     << "TypeError: "
+                     << "Binary operators must have the same datatype for both 
operands.  "
+                     << "However, " << call << " uses datatype " << lhs_dtype
+                     << " on the LHS (StructInfo of " << lhs_sinfo << "), and 
datatype "
+                     << rhs_dtype << " on the RHS (StructInfo of " << 
rhs_sinfo << ").");
   }
-  return x1_sinfo->dtype;
+  return lhs_dtype;
 }
 
 /*!
  * \brief Infer the output virtual device for binary arithmetic operators.
  * \param call The context Call to the operator.
  * \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
  * \return The inferred output vdevice.
  * \throw Throw exception if the vdevice of two input TensorStructInfo don’t 
match
  */
 inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const 
BlockBuilder& ctx,
-                                                      const TensorStructInfo& 
x1_sinfo,
-                                                      const TensorStructInfo& 
x2_sinfo) {
-  if (!x1_sinfo->vdevice.defined() || 
!x1_sinfo->vdevice.value()->target.defined()) {
-    return x2_sinfo->vdevice;
+                                                      const StructInfo& 
lhs_sinfo,
+                                                      const StructInfo& 
rhs_sinfo) {
+  auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> {
+    if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+      return tensor->vdevice;
+    } else {
+      return NullOpt;
+    }
+  };
+
+  auto lhs_vdevice = get_vdevice(lhs_sinfo);
+  auto rhs_vdevice = get_vdevice(rhs_sinfo);
+
+  if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) {
+    return rhs_vdevice;
   }
-  if (!x2_sinfo->vdevice.defined() || 
!x2_sinfo->vdevice.value()->target.defined()) {
-    return x1_sinfo->vdevice;
+  if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) {
+    return lhs_vdevice;
   }
-  if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
+  if (lhs_vdevice.value() != rhs_vdevice.value()) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "VDevice " << x1_sinfo->vdevice.value() << " and "
-                     << x2_sinfo->vdevice.value() << " must be equal for 
binary operators");
+                     << "TypeError: "
+                     << "Binary operators with Tensor arguments "
+                     << "must have the same VDevice for both operands.  "
+                     << "However, " << call << " has a LHS on VDevice " << 
lhs_vdevice
+                     << " and a RHS on VDevice " << rhs_vdevice);
   }
-  return x1_sinfo->vdevice;
+  return lhs_vdevice;
 }
 
 /*!
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index f1427156e0..afc0fb7303 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -32,43 +32,103 @@ namespace relax {
 template <typename FType>
 StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
                                     FType f_compute_out_dtype) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
-  TensorStructInfo x1_sinfo = input_sinfo[0];
-  TensorStructInfo x2_sinfo = input_sinfo[1];
+  Op op = Downcast<Op>(call->op);
+  size_t n_input = op->arguments.size();
+  if (call->args.size() != n_input) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << call->op << " op should have " << n_input << " 
arguments");
+  }
+
+  auto lhs_sinfo = GetStructInfo(call->args[0]);
+  auto rhs_sinfo = GetStructInfo(call->args[1]);
+
+  CHECK(lhs_sinfo.as<PrimStructInfoNode>() || 
lhs_sinfo.as<TensorStructInfoNode>())
+      << "TypeError: "
+      << "Arguments to binary operators must be either R.Tensor or R.Prim 
types, "
+      << "but expression " << call << " has LHS " << call->args[0] << ", which 
has StructInfo "
+      << lhs_sinfo;
+  CHECK(rhs_sinfo.as<PrimStructInfoNode>() || 
rhs_sinfo.as<TensorStructInfoNode>())
+      << "TypeError: "
+      << "Arguments to binary operators must be either R.Tensor or R.Prim 
types, "
+      << "but expression " << call << " has RHS " << call->args[1] << ", which 
has StructInfo "
+      << rhs_sinfo;
 
   // DateType
-  DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo);
+  DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_sinfo, rhs_sinfo);
+
+  if (lhs_sinfo.as<PrimStructInfoNode>() && 
rhs_sinfo.as<PrimStructInfoNode>()) {
+    return PrimStructInfo(output_dtype);
+  }
 
   // VDevice
-  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, 
x1_sinfo, x2_sinfo);
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, 
lhs_sinfo, rhs_sinfo);
+
+  auto get_ndim = [&](const StructInfo& sinfo) -> int {
+    if (sinfo.as<PrimStructInfoNode>()) {
+      return 0;  // A PrimValue behaves as a 0-d tensor.
+    } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+      return tensor->ndim;
+    } else {
+      return kUnknownNDim;
+    }
+  };

   // ndims
-  int output_ndim;
-  if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
-    output_ndim = kUnknownNDim;
-  } else {
-    output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim);
-  }
+  int output_ndim = [&]() {
+    int lhs_ndim = get_ndim(lhs_sinfo);
+    int rhs_ndim = get_ndim(rhs_sinfo);
+    if (lhs_ndim == kUnknownNDim || rhs_ndim == kUnknownNDim) {
+      return kUnknownNDim;
+    } else {
+      return std::max(lhs_ndim, rhs_ndim);
+    }
+  }();

-  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
-  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
-  // Shapes and ndims
-  if (x1_shape && x2_shape) {
-    // If all inputs have shapes, directly infer shapes
-    Optional<Array<PrimExpr>> output_shape =
-        InferBinaryBroadcastShape(call, ctx, x1_shape->values, 
x2_shape->values);
-    if (!output_shape.defined()) {
-      return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
+  // Shapes.  A PrimValue is treated as a 0-d tensor (empty shape).
 
+  auto get_shape = [](const StructInfo& sinfo) -> Optional<Array<PrimExpr>> {
+    if (sinfo.as<PrimStructInfoNode>()) {
+      return Array<PrimExpr>{};
+    } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+      return tensor->GetShape();
     } else {
+      return NullOpt;
+    }
+  };
+
+  // If both inputs have a known shape, directly infer the shape of
+  // the output.
+  auto lhs_shape = get_shape(lhs_sinfo);
+  auto rhs_shape = get_shape(rhs_sinfo);
+  if (lhs_shape && rhs_shape) {
+    Optional<Array<PrimExpr>> output_shape =
+        InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), 
rhs_shape.value());
+    if (output_shape.defined()) {
       ICHECK_EQ(static_cast<int>(output_shape.value().size()), output_ndim);
       return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, 
vdevice);
     }
-  } else if (x1_sinfo->shape.defined() && 
x1_sinfo->shape.same_as(x2_sinfo->shape)) {
-    return TensorStructInfo(x1_sinfo->shape.value(), output_dtype, vdevice);
-  } else {
-    return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
   }
+
+  auto get_shape_expr = [](const StructInfo& sinfo) -> Optional<Expr> {
+    if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+      return tensor->shape;
+    } else {
+      return NullOpt;
+    }
+  };
+
+  // If the shapes are not statically known, but both inputs use the
+  // same shape expression (e.g. the same `ShapeStructInfo` variable),
+  // then propagate that expression to the output.
+  auto lhs_shape_expr = get_shape_expr(lhs_sinfo);
+  auto rhs_shape_expr = get_shape_expr(rhs_sinfo);
+  if (lhs_shape_expr.defined() && lhs_shape_expr.same_as(rhs_shape_expr)) {
+    return TensorStructInfo(lhs_shape_expr.value(), output_dtype, vdevice);
+  }
+
+  // If neither of those cases holds, then fall back to an unknown
+  // shape with `output_ndim` dimensionality.
+  return TensorStructInfo(output_dtype, output_ndim, vdevice);
 }
 
 StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& 
ctx) {
@@ -78,8 +138,8 @@ StructInfo InferStructInfoBroadcastArith(const Call& call, 
const BlockBuilder& c
 StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& 
ctx) {
   return InferStructInfoBroadcast(
       call, ctx,
-      [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& 
x1_sinfo,
-         const TensorStructInfo& x2_sinfo) { return DataType::Bool(); });
+      [](const Call& call, const BlockBuilder& ctx, const StructInfo& 
lhs_sinfo,
+         const StructInfo& rhs_sinfo) { return DataType::Bool(); });
 }
 
 InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc
index 7c7752cfe6..1a9c5d0546 100644
--- a/src/script/printer/relax/tir.cc
+++ b/src/script/printer/relax/tir.cc
@@ -41,9 +41,10 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) {
 }
 
 Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
-  ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only 
uses "
-                                                       "scalar integer TIR 
variables, but gets: "
-                                                    << n;
+  ICHECK(n->dtype.is_scalar()) << "TypeError: "
+                               << "Relax only uses scalar TIR variables, "
+                               << "but received TIR variable " << n << " with 
dtype " << n->dtype;
+
   if (!d->IsVarDefined(n)) {
     RelaxFrameNode* f = GetRelaxFrame(d);
     // There should be at least one Relax frame
diff --git a/src/te/operation/create_primfunc.cc 
b/src/te/operation/create_primfunc.cc
index 0dc8b38701..03de68e326 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -488,7 +488,9 @@ void RewriteStageToBlock(const te::Operation& op, 
CreateFuncInfo* info, Array<St
     ICHECK_EQ(op->num_outputs(), 1);
     const te::Tensor& tensor = op.output(0);
     // Check op is in op list
-    ICHECK(info->IsArg(tensor));
+    ICHECK(info->IsArg(tensor)) << "The operation " << op << " produces tensor 
" << tensor
+                                << ", but this tensor does not appear as a 
function argument.  "
+                                << "The function accepts arguments " << 
info->arg_list;
     // Declare a buffer for any argument tensors without a pre-existing
     // buffer declaration recorded in the tensor2buffer binds map
     if (info->tensor2buffers.count(tensor) == 0) {
@@ -581,17 +583,16 @@ PrimFunc GenerateAndCompletePrimFunc(const 
Array<ObjectRef>& arg_tir_var_list,
                                      const Array<Stmt>& root_stmts, 
CreateFuncInfo* info) {
   Array<Var> parameters;
   Map<Var, Buffer> buffer_map;
-  for (const ObjectRef& x : arg_tir_var_list) {
-    if (auto n = x.as<te::TensorNode>()) {
-      te::Tensor tensor = GetRef<te::Tensor>(n);
+  for (const ObjectRef& obj : arg_tir_var_list) {
+    if (auto opt_tensor = obj.as<te::Tensor>()) {
+      te::Tensor tensor = opt_tensor.value();
       Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
       parameters.push_back(arg);
       auto it = info->tensor2buffers.find(tensor);
       ICHECK(it != info->tensor2buffers.end());
       buffer_map.Set(arg, it->second);
-    } else if (auto n = x.as<tir::VarNode>()) {
-      tir::Var var = GetRef<tir::Var>(n);
-      parameters.push_back(var);
+    } else if (auto var = obj.as<tir::Var>()) {
+      parameters.push_back(var.value());
     }
   }
   PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
diff --git a/tests/python/relax/test_op_binary.py 
b/tests/python/relax/test_op_binary.py
index a0ec08f0ab..85842f1578 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -59,15 +59,15 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
 
 
-(binary_arith_op,) = tvm.testing.parameters(
-    (relax.op.add,),
-    (relax.op.divide,),
-    (relax.op.floor_divide,),
-    (relax.op.multiply,),
-    (relax.op.power,),
-    (relax.op.subtract,),
-    (relax.op.maximum,),
-    (relax.op.minimum,),
+(binary_arith_op, tir_arith_op) = tvm.testing.parameters(
+    (relax.op.add, tir.Add),
+    (relax.op.divide, tir.Div),
+    (relax.op.floor_divide, tir.FloorDiv),
+    (relax.op.multiply, tir.Mul),
+    (relax.op.power, tir.pow),
+    (relax.op.subtract, tir.Sub),
+    (relax.op.maximum, tir.Max),
+    (relax.op.minimum, tir.Min),
 )
 
 
@@ -115,13 +115,47 @@ def test_binary_arith_infer_struct_info(binary_arith_op: 
Callable):
     )
 
 
-(binary_cmp_op,) = tvm.testing.parameters(
-    (relax.op.equal,),
-    (relax.op.greater,),
-    (relax.op.greater_equal,),
-    (relax.op.less,),
-    (relax.op.less_equal,),
-    (relax.op.not_equal,),
+def 
test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op: 
Callable):
+    bb = relax.BlockBuilder()
+
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Prim("float32"))
+
+    _check_inference(bb, binary_arith_op(x, y), relax.TensorStructInfo((2, 3), 
"float32"))
+
+
+def 
test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_op: 
Callable):
+    bb = relax.BlockBuilder()
+
+    x = relax.Var("x", R.Prim("float32"))
+    y = relax.Var("y", R.Prim("float32"))
+
+    _check_inference(bb, binary_arith_op(x, y), 
relax.PrimStructInfo("float32"))
+
+
+@pytest.mark.xfail(reason="Not yet implemented")
+def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value(
+    binary_arith_op: Callable, tir_arith_op
+):
+    bb = relax.BlockBuilder()
+
+    tir_x = tir.Var("tir_x", "float32")
+    tir_y = tir.Var("tir_y", "float32")
+
+    x = relax.Var("x", R.Prim(value=tir_x))
+    y = relax.Var("y", R.Prim(value=tir_y))
+
+    _check_inference(bb, binary_arith_op(x, y), 
relax.PrimStructInfo(value=tir_arith_op(tir_x, tir_y)))
+    _check_inference(bb, binary_arith_op(y, x), 
relax.PrimStructInfo(value=tir_arith_op(tir_y, tir_x)))
+
+
+(binary_cmp_op, tir_cmp_op) = tvm.testing.parameters(
+    (relax.op.equal, tir.EQ),
+    (relax.op.greater, tir.GT),
+    (relax.op.greater_equal, tir.GE),
+    (relax.op.less, tir.LT),
+    (relax.op.less_equal, tir.LE),
+    (relax.op.not_equal, tir.NE),
 )
 
 
@@ -141,6 +175,38 @@ def test_binary_cmp_infer_struct_info(binary_cmp_op: 
Callable):
     _check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3), 
"bool", vdev0))
 
 
+def test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op: 
Callable):
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Prim("float32"))
+    _check_inference(bb, binary_cmp_op(x, y), relax.TensorStructInfo((2, 3), 
"bool"))
+    _check_inference(bb, binary_cmp_op(y, x), relax.TensorStructInfo((2, 3), 
"bool"))
+
+
+def test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op: 
Callable):
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Prim("float32"))
+    y = relax.Var("y", R.Prim("float32"))
+    _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo("bool"))
+    _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo("bool"))
+
+
+@pytest.mark.xfail(reason="Not yet implemented")
+def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value(
+    binary_cmp_op: Callable, tir_cmp_op
+):
+    bb = relax.BlockBuilder()
+
+    tir_x = tir.Var("tir_x", "float32")
+    tir_y = tir.Var("tir_y", "float32")
+
+    x = relax.Var("x", R.Prim(value=tir_x))
+    y = relax.Var("y", R.Prim(value=tir_y))
+
+    _check_inference(bb, binary_cmp_op(x, y), 
relax.PrimStructInfo(value=tir_cmp_op(tir_x, tir_y)))
+    _check_inference(bb, binary_cmp_op(y, x), 
relax.PrimStructInfo(value=tir_cmp_op(tir_y, tir_x)))
+
+
 def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
     bb = relax.BlockBuilder()
     m = tir.Var("m", "int64")
@@ -216,7 +282,7 @@ def 
test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     y = relax.Var("y", R.Tensor((2, 3), "int32"))
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(binary_arith_op(x, y))
 
 
@@ -224,7 +290,7 @@ def 
test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callab
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm")))
     y = relax.Var("y", R.Tensor((2, 3), "int32", VDevice("cuda")))
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(binary_arith_op(x, y))
 
 
@@ -245,9 +311,9 @@ def 
test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable):
     x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
     y = relax.Var("y", R.Tensor((2, 3), "float32"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(binary_arith_op(x0, y))
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(binary_arith_op(x1, y))
 
 
diff --git a/tests/python/relax/test_op_nn_convolution.py 
b/tests/python/relax/test_op_nn_convolution.py
index 55e35ee203..588dc9b1b1 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -386,7 +386,7 @@ def test_conv1d_dtype_mismatch():
     x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     w = relax.Var("w", R.Tensor((4, 3, 3), "int8"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.nn.conv1d(x, w))
 
 
@@ -744,7 +744,7 @@ def test_conv1d_transpose_dtype_mismatch():
     x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     w = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.nn.conv1d_transpose(x, w))
 
 
@@ -1141,7 +1141,7 @@ def test_conv2d_dtype_mismatch():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.nn.conv2d(x, w))
 
 
@@ -1533,7 +1533,7 @@ def test_conv2d_transpose_dtype_mismatch():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.nn.conv2d_transpose(x, w))
 
 
diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py
index 21f022d9eb..e67ef442f9 100644
--- a/tests/python/relax/test_op_search.py
+++ b/tests/python/relax/test_op_search.py
@@ -262,9 +262,9 @@ def test_where_infer_struct_info_dtype_mismatch():
     x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
     y1 = relax.Var("y", R.Tensor((2, 3), "float32"))
 
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.where(cond, x0, y0))
-    with pytest.raises(TVMError):
+    with pytest.raises(TypeError):
         bb.normalize(relax.op.where(cond, x1, y1))
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py
index d71a248b25..7b94057824 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -17,7 +17,7 @@
 
 import tvm
 from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
 import tvm.testing
 
 
@@ -164,6 +164,44 @@ def test_add_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_add_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.add(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.add, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def add(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_add"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] + rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_divide():
     # fmt: off
     @tvm.script.ir_module
@@ -303,6 +341,44 @@ def test_divide_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_divide_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.divide(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.divide, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def divide(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_divide"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] / rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_floor_divide():
     # fmt: off
     @tvm.script.ir_module
@@ -442,6 +518,44 @@ def test_floor_divide_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_floordiv_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.floor_divide(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.floor_divide, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def floor_divide(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_floor_divide"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.floor(lhs[vi, vj, vk] / rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_multiply():
     # fmt: off
     @tvm.script.ir_module
@@ -519,6 +633,44 @@ def test_multiply_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_multiply_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.multiply(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.multiply, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def multiply(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_multiply"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] * rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_power():
     # fmt: off
     @tvm.script.ir_module
@@ -599,6 +751,44 @@ def test_power_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_power_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.power(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.power, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def power(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_power"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.pow(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_subtract():
     # fmt: off
     @tvm.script.ir_module
@@ -676,6 +866,44 @@ def test_subtract_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_subtract_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.subtract(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.subtract, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def subtract(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_subtract"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] - rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 ##################### Binary comparison #####################
 
 
@@ -818,6 +1046,44 @@ def test_equal_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_equal_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] == rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_greater():
     # fmt: off
     @tvm.script.ir_module
@@ -957,6 +1223,44 @@ def test_greater_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_greater_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.greater(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.greater, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def greater(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_greater"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = rhs < lhs[vi, vj, vk]  # i.e. lhs > rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_greater_equal():
     # fmt: off
     @tvm.script.ir_module
@@ -1034,6 +1338,44 @@ def test_greater_equal_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_greater_equal_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.greater_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.greater_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def greater_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_greater_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = rhs <= lhs[vi, vj, vk]  # i.e. lhs >= rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_less():
     # fmt: off
     @tvm.script.ir_module
@@ -1111,6 +1453,44 @@ def test_less_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_less_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.less(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.less, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def less(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_less"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] < rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_less_equal():
     # fmt: off
     @tvm.script.ir_module
@@ -1250,6 +1630,44 @@ def test_less_equal_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_less_equal_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.less_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.less_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def less_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_less_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] <= rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_not_equal():
     # fmt: off
     @tvm.script.ir_module
@@ -1327,6 +1745,44 @@ def test_not_equal_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_not_equal_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.not_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.not_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def not_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_not_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] != rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_maximum():
     # fmt: off
     @tvm.script.ir_module
@@ -1467,6 +1923,44 @@ def test_maximum_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_max_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.maximum(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.maximum, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def maximum(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_maximum"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.max(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_minimum():
     # fmt: off
     @tvm.script.ir_module
@@ -1607,5 +2101,43 @@ def test_minimum_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_min_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.minimum(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.minimum, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def minimum(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,  # R.Prim argument is lowered to a scalar TIR param
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_minimum"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.min(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to