This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 622bd150dd [Relax] Handle binary operations between Tensor and PrimValue (#16827)
622bd150dd is described below
commit 622bd150dd331780eb41a1c67c65aae802eb9b20
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Apr 18 16:41:59 2024 -0500
[Relax] Handle binary operations between Tensor and PrimValue (#16827)
* [Relax] Handle binary operations between Tensor and PrimValue
Prior to this commit, binary operations were only defined between two
tensors. This commit allows binary operations to apply between a
tensor and a `relax::PrimValue`, or between two `relax::PrimValue`s.
When inferring the output `StructInfo`, binary operations with a
`PrimValue` produce the same output as using a 0-d tensor, except that
an operation between two `PrimValue`s produces a `PrimValue`. When
legalizing operations containing a `PrimValue`, each `PrimValue`
operand is lowered to a scalar TIR parameter of the generated PrimFunc.
* Fix unit tests
* Restore ICHECK for scalar TIR variable
* Fix a few more unit tests
* Remove handling of ObjectStructInfo
* Undo commenting-out of test cases
* Update for improved error messages
* Fix failing unit tests
* Fix unit test
---
python/tvm/relax/utils.py | 130 +++--
src/relax/op/op_common.h | 103 +++-
src/relax/op/tensor/binary.cc | 112 ++++-
src/script/printer/relax/tir.cc | 7 +-
src/te/operation/create_primfunc.cc | 15 +-
tests/python/relax/test_op_binary.py | 106 +++-
tests/python/relax/test_op_nn_convolution.py | 8 +-
tests/python/relax/test_op_search.py | 4 +-
.../relax/test_transform_legalize_ops_binary.py | 534 ++++++++++++++++++++-
9 files changed, 887 insertions(+), 132 deletions(-)
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index a58b65477c..48beeed8da 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -14,13 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
# pylint: disable=invalid-name,too-many-locals
+
"""Utility functions for Relax"""
+
import functools
import inspect
+import itertools
+import string
+
from typing import Tuple as typing_Tuple
from typing import Any, Callable, List, Dict, Optional, TypeVar
+import tvm
from .. import tir
from ..tir import PrimExpr
from ..runtime import String, convert_to_object
@@ -302,9 +309,23 @@ def gen_call_tir_inputs(
out_sinfo, and tir_vars.
"""
- def _convert_te_arg(
- te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
- ) -> typing_Tuple[Any, List[te_Tensor]]:
+ tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
+
+ call_tir_args = []
+ create_primfunc_args = []
+ # extra list of tir expression arguments
+ # that are not covered by Tensor
+ extra_tir_args_list = []
+
+ def _copy_undefined_var(expr: tir.PrimExpr):
+ def _visit_expr(e: tir.PrimExpr):
+ if isinstance(e, tir.Var) and e not in tir_var_map:
+ new_var = tir.Var(e.name, e.dtype)
+ tir_var_map[e] = new_var
+
+ tir.stmt_functor.post_order_visit(expr, _visit_expr)
+
+ def _convert_te_arg(te_args: Any) -> Any:
"""Helper function used to convert Relax expressions to TE tensor.
In the common case, the type of te_args is a Relax expression and is
converted
@@ -335,23 +356,8 @@ def gen_call_tir_inputs(
A tuple of the converted te_args, and a list of te tensors for
each converted
Relax expression
"""
- te_args_list = []
- # extra list of tir expression arguments
- # that are not covered by Tensor
- extra_tir_args_list = []
-
- def _copy_undefined_var(expr: tir.PrimExpr):
- def _visit_expr(e: tir.PrimExpr):
- if isinstance(e, tir.Var) and e not in tir_var_map:
- new_var = tir.Var(e.name, e.dtype)
- tir_var_map[e] = new_var
-
- tir.stmt_functor.post_order_visit(expr, _visit_expr)
-
- n_tensor = 0
def _convert_te_arg_helper(arg):
- nonlocal n_tensor
if isinstance(arg, Expr): # type: ignore
if isinstance(arg.struct_info, TensorStructInfo):
assert isinstance(
@@ -360,21 +366,46 @@ def gen_call_tir_inputs(
for shape_value in arg.struct_info.shape.values:
_copy_undefined_var(shape_value)
- name = chr(ord("A") + n_tensor) if n_tensor < 26 else
f"input{n_tensor}"
- arg = te_tensor(arg, tir_var_map, name)
- n_tensor += 1
- te_args_list.append(arg)
- return arg
+ n_args = len(create_primfunc_args)
+ if isinstance(arg, tvm.relax.Var):
+ name = arg.name_hint
+ elif n_args < len(string.ascii_uppercase):
+ name = string.ascii_uppercase[n_args]
+ else:
+ name = f"tensor_input_{n_args}"
+
+ te_arg = te_tensor(arg, tir_var_map, name)
+
+ call_tir_args.append(arg)
+ create_primfunc_args.append(te_arg)
+
+ return te_arg
+
if isinstance(arg.struct_info, ShapeStructInfo):
assert isinstance(
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only
supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
- if (
- isinstance(arg.struct_info, PrimStructInfo)
- and arg.struct_info.value is not None
- ):
- return _convert_te_arg_helper(arg.struct_info.value)
+
+ if isinstance(arg.struct_info, PrimStructInfo):
+ if arg.struct_info.value is None:
+ n_args = len(create_primfunc_args)
+ if isinstance(arg, tvm.relax.Var):
+ name = arg.name_hint
+ elif n_args < len(string.ascii_lowercase):
+ name = string.ascii_lowercase[n_args]
+ else:
+ name = f"scalar_input_{n_args}"
+
+ tir_param = tir.Var(name, arg.struct_info.dtype)
+
+ call_tir_args.append(arg)
+ create_primfunc_args.append(tir_param)
+
+ return tir_param
+ else:
+ return _convert_te_arg_helper(arg.struct_info.value)
+
elif isinstance(arg, (list, Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
@@ -395,28 +426,36 @@ def gen_call_tir_inputs(
raise TypeError("not supported type in emit_te:
{}".format(type(arg)))
new_arg = _convert_te_arg_helper(te_args)
- return new_arg, te_args_list, extra_tir_args_list
+ return new_arg
def _get_unbound_tir_vars(
args: List[te_Tensor], extra_tir_args: List[PrimExpr]
) -> List[tir.Var]:
"""get unbound TIR vars (i.e TIR vars used in the shape but is not
itself a dimension of a shape)"""
+
bound_vars = set()
used_vars = set()
+ def _populate_bound_vars(expr):
+ if isinstance(expr, te_Tensor):
+ for dim in expr.shape:
+ _populate_bound_vars(dim)
+ elif isinstance(expr, tir.Var):
+ bound_vars.add(expr)
+
def _populate_used_vars(expr):
- if isinstance(expr, tir.Var):
- used_vars.add(expr)
+ if isinstance(expr, te_Tensor):
+ for dim in expr.shape:
+ _populate_used_vars(dim)
+ elif isinstance(expr, tir.PrimExpr):
+ used_vars.update(tir.analysis.undefined_vars(expr))
- for val in extra_tir_args:
- tir.stmt_functor.post_order_visit(val, _populate_used_vars)
+ for arg in itertools.chain(args, extra_tir_args):
+ _populate_used_vars(arg)
- for x in args:
- for s in x.shape:
- tir.stmt_functor.post_order_visit(s, _populate_used_vars)
- if isinstance(s, tir.Var):
- bound_vars.add(s)
+ for arg in args:
+ _populate_bound_vars(arg)
diff = used_vars - bound_vars
return list(diff)
@@ -448,21 +487,18 @@ def gen_call_tir_inputs(
primfunc_attrs = kwargs.pop("primfunc_attrs", None)
- tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
- new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
- new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs,
tir_var_map)
-
- te_args = te_arg_list + te_kwarg_list
+ te_args = _convert_te_arg(args)
+ te_kwargs = _convert_te_arg(kwargs)
- te_out = func(*new_args, **new_kwargs)
+ te_out = func(*te_args, **te_kwargs)
assert isinstance(te_out, te_Tensor) or (
isinstance(te_out, (tuple, list, Array)) and all(isinstance(t,
te_Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function
output"
outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
- unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list +
tir_kwarg_list)
+ unbound_tir_vars = _get_unbound_tir_vars([*create_primfunc_args, *outs],
extra_tir_args_list)
- inputs = [*te_args] + outs + unbound_tir_vars
+ inputs = [*create_primfunc_args] + outs + unbound_tir_vars
tir_func = create_prim_func(inputs, "int64")
if primfunc_attrs:
@@ -470,8 +506,6 @@ def gen_call_tir_inputs(
tir_func = tir_func.without_attr("global_symbol")
- call_tir_args = [x.op.value for x in te_args]
-
# Invert the TIR variable mapping, to convert the output shape back
# with old set of variables.
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index f5eed7af06..5e19edb47c 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -31,6 +31,7 @@
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
+#include <optional>
#include <tuple>
#include <utility>
#include <vector>
@@ -239,52 +240,112 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
const Map<String, Array<String>>&
desired_layouts,
const VarLayoutMap& var_layout_map);
+/*!
+ * \brief Get the element dtype from StructInfo
+ *
+ * \param sinfo The StructInfo to expect
+ * \return The inferred element dtype.
+ * \throw Throw exception if the StructInfo doesn't have an element type.
+ */
+inline std::optional<DataType> GetElementDType(const StructInfo& sinfo) {
+ if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
+ return prim->dtype;
+ } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->dtype;
+ } else {
+ return std::nullopt;
+ LOG(FATAL) << "TypeError: "
+ << "Only PrimStructInfo and TensorStructInfo "
+ << "have an associated data type. "
+ << "Cannot determine element type of " << sinfo;
+ }
+}
+
/*!
* \brief Infer the output datatype for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
* \return The inferred output dtype.
* \throw Throw exception if the dtype of two input TensorStructInfo don’t
match
*/
inline DataType InferBinaryArithOpOutDtype(const Call& call, const
BlockBuilder& ctx,
- const TensorStructInfo& x1_sinfo,
- const TensorStructInfo& x2_sinfo) {
- if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) {
+ const StructInfo& lhs_sinfo,
+ const StructInfo& rhs_sinfo) {
+ auto opt_lhs_dtype = GetElementDType(lhs_sinfo);
+ if (!opt_lhs_dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "TypeError: "
+ << "Binary operators must have the same datatype for both
operands. "
+ << "However, " << call << " has argument " <<
call->args[0]
+ << " on the LHS, with struct info " << lhs_sinfo << ".
This is of type "
+ << lhs_sinfo->GetTypeKey() << ", which does not have a
datatype.");
+ }
+ auto lhs_dtype = opt_lhs_dtype.value();
+
+ auto opt_rhs_dtype = GetElementDType(rhs_sinfo);
+ if (!opt_rhs_dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "TypeError: "
+ << "Binary operators must have the same datatype for both
operands. "
+ << "However, " << call << " has argument " <<
call->args[1]
+ << " on the RHS, with struct info " << rhs_sinfo << ".
This is of type "
+ << rhs_sinfo->GetTypeKey() << ", which does not have a
datatype.");
+ }
+ auto rhs_dtype = opt_rhs_dtype.value();
+
+ if (lhs_dtype.is_void() || rhs_dtype.is_void()) {
return DataType::Void();
- } else if (x1_sinfo->dtype != x2_sinfo->dtype) {
+ } else if (lhs_dtype != rhs_dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "Data types " << x1_sinfo->dtype << " and " <<
x2_sinfo->dtype
- << " must be equal for binary operators");
+ << "TypeError: "
+ << "Binary operators must have the same datatype for both
operands. "
+ << "However, " << call << " uses datatype " << lhs_dtype
+ << " on the LHS (StructInfo of " << lhs_sinfo << "), and
datatype "
+ << rhs_dtype << " on the RHS (StructInfo of " <<
rhs_sinfo << ").");
}
- return x1_sinfo->dtype;
+ return lhs_dtype;
}
/*!
* \brief Infer the output virtual device for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
- * \param x1_sinfo The struct info of the first operand
- * \param x2_sinfo The struct info of the second operand
+ * \param lhs_sinfo The struct info of the first operand
+ * \param rhs_sinfo The struct info of the second operand
* \return The inferred output vdevice.
* \throw Throw exception if the vdevice of two input TensorStructInfo don’t
match
*/
inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const
BlockBuilder& ctx,
- const TensorStructInfo&
x1_sinfo,
- const TensorStructInfo&
x2_sinfo) {
- if (!x1_sinfo->vdevice.defined() ||
!x1_sinfo->vdevice.value()->target.defined()) {
- return x2_sinfo->vdevice;
+ const StructInfo&
lhs_sinfo,
+ const StructInfo&
rhs_sinfo) {
+ auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> {
+ if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->vdevice;
+ } else {
+ return NullOpt;
+ }
+ };
+
+ auto lhs_vdevice = get_vdevice(lhs_sinfo);
+ auto rhs_vdevice = get_vdevice(rhs_sinfo);
+
+ if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) {
+ return rhs_vdevice;
}
- if (!x2_sinfo->vdevice.defined() ||
!x2_sinfo->vdevice.value()->target.defined()) {
- return x1_sinfo->vdevice;
+ if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) {
+ return lhs_vdevice;
}
- if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
+ if (lhs_vdevice.value() != rhs_vdevice.value()) {
ctx->ReportFatal(Diagnostic::Error(call)
- << "VDevice " << x1_sinfo->vdevice.value() << " and "
- << x2_sinfo->vdevice.value() << " must be equal for
binary operators");
+ << "TypeErorr: "
+ << "Binary operators with Tensor arguments "
+ << "must have the same VDevice for both operands. "
+ << "However, " << call << " has a LHS on VDevice " <<
lhs_vdevice
+ << " and a RHS on VDevice " << rhs_vdevice);
}
- return x1_sinfo->vdevice;
+ return lhs_vdevice;
}
/*!
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index f1427156e0..afc0fb7303 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -32,43 +32,103 @@ namespace relax {
template <typename FType>
StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
FType f_compute_out_dtype) {
- Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
- TensorStructInfo x1_sinfo = input_sinfo[0];
- TensorStructInfo x2_sinfo = input_sinfo[1];
+ Op op = Downcast<Op>(call->op);
+ size_t n_input = op->arguments.size();
+ if (call->args.size() != n_input) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << call->op << " op should have " << n_input << "
arguments");
+ }
+
+ auto lhs_sinfo = GetStructInfo(call->args[0]);
+ auto rhs_sinfo = GetStructInfo(call->args[1]);
+
+ CHECK(lhs_sinfo.as<PrimStructInfoNode>() ||
lhs_sinfo.as<TensorStructInfoNode>())
+ << "TypeError: "
+ << "Arguments to binary operators must be either R.Tensor or R.Prim
types, "
+ << "but expression " << call << " has LHS " << call->args[0] << ", which
has StructInfo "
+ << lhs_sinfo;
+ CHECK(rhs_sinfo.as<PrimStructInfoNode>() ||
rhs_sinfo.as<TensorStructInfoNode>())
+ << "TypeError: "
+ << "Arguments to binary operators must be either R.Tensor or R.Prim
types, "
+ << "but expression " << call << " has RHS " << call->args[1] << ", which
has StructInfo "
+ << rhs_sinfo;
// DateType
- DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo);
+ DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_sinfo, rhs_sinfo);
+
+ if (lhs_sinfo.as<PrimStructInfoNode>() &&
rhs_sinfo.as<PrimStructInfoNode>()) {
+ return PrimStructInfo(output_dtype);
+ }
// VDevice
- Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx,
x1_sinfo, x2_sinfo);
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx,
lhs_sinfo, rhs_sinfo);
+
+ auto get_ndim = [&](const StructInfo& sinfo) -> int {
+ if (sinfo.as<PrimStructInfoNode>()) {
+ return 1;
+ } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->ndim;
+ } else {
+ return kUnknownNDim;
+ }
+ };
// ndims
- int output_ndim;
- if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
- output_ndim = kUnknownNDim;
- } else {
- output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim);
- }
+ int output_ndim = [&]() {
+ int lhs_ndim = get_ndim(lhs_sinfo);
+ int rhs_ndim = get_ndim(rhs_sinfo);
+ if (lhs_ndim == kUnknownNDim || rhs_ndim == kUnknownNDim) {
+ return kUnknownNDim;
+ } else {
+ return std::max(lhs_ndim, rhs_ndim);
+ }
+ }();
- const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
- const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
- // Shapes and ndims
- if (x1_shape && x2_shape) {
- // If all inputs have shapes, directly infer shapes
- Optional<Array<PrimExpr>> output_shape =
- InferBinaryBroadcastShape(call, ctx, x1_shape->values,
x2_shape->values);
- if (!output_shape.defined()) {
- return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
+ // Shapes
+ auto get_shape = [](const StructInfo& sinfo) -> Optional<Array<PrimExpr>> {
+ if (sinfo.as<PrimStructInfoNode>()) {
+ return Array<PrimExpr>{IntImm(DataType::Int(64), 1)};
+ } else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->GetShape();
} else {
+ return NullOpt;
+ }
+ };
+
+ // If both inputs have a known shape, directly infer the shape of
+ // the output.
+ auto lhs_shape = get_shape(lhs_sinfo);
+ auto rhs_shape = get_shape(rhs_sinfo);
+ if (lhs_shape && rhs_shape) {
+ Optional<Array<PrimExpr>> output_shape =
+ InferBinaryBroadcastShape(call, ctx, lhs_shape.value(),
rhs_shape.value());
+ if (output_shape.defined()) {
ICHECK_EQ(static_cast<int>(output_shape.value().size()), output_ndim);
return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype,
vdevice);
}
- } else if (x1_sinfo->shape.defined() &&
x1_sinfo->shape.same_as(x2_sinfo->shape)) {
- return TensorStructInfo(x1_sinfo->shape.value(), output_dtype, vdevice);
- } else {
- return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
}
+
+ auto get_shape_expr = [](const StructInfo& sinfo) -> Optional<Expr> {
+ if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ return tensor->shape;
+ } else {
+ return NullOpt;
+ }
+ };
+
+ // If the input shape is unknown, but both inputs have the same
+ // `ShapeStructInfo`variable for their shape, then propagate that
+ // variable to the output.
+ auto lhs_shape_expr = get_shape_expr(lhs_sinfo);
+ auto rhs_shape_expr = get_shape_expr(rhs_sinfo);
+ if (lhs_shape_expr.defined() && lhs_shape_expr.same_as(rhs_shape_expr)) {
+ return TensorStructInfo(lhs_shape_expr.value(), output_dtype, vdevice);
+ }
+
+ // If neither of those cases holds, then fall back to an unknown
+ // shape with `output_ndim` dimensionality.
+ return TensorStructInfo(output_dtype, output_ndim, vdevice);
}
StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder&
ctx) {
@@ -78,8 +138,8 @@ StructInfo InferStructInfoBroadcastArith(const Call& call,
const BlockBuilder& c
StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder&
ctx) {
return InferStructInfoBroadcast(
call, ctx,
- [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo&
x1_sinfo,
- const TensorStructInfo& x2_sinfo) { return DataType::Bool(); });
+ [](const Call& call, const BlockBuilder& ctx, const StructInfo&
lhs_sinfo,
+ const StructInfo& rhs_sinfo) { return DataType::Bool(); });
}
InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc
index 7c7752cfe6..1a9c5d0546 100644
--- a/src/script/printer/relax/tir.cc
+++ b/src/script/printer/relax/tir.cc
@@ -41,9 +41,10 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) {
}
Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
- ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only
uses "
- "scalar integer TIR
variables, but gets: "
- << n;
+ ICHECK(n->dtype.is_scalar()) << "TypeError: "
+ << "Relax only uses scalar TIR variables,"
+ << "but received TIR variable " << n << " with
dtype " << n->dtype;
+
if (!d->IsVarDefined(n)) {
RelaxFrameNode* f = GetRelaxFrame(d);
// There should be at least one Relax frame
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index 0dc8b38701..03de68e326 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -488,7 +488,9 @@ void RewriteStageToBlock(const te::Operation& op,
CreateFuncInfo* info, Array<St
ICHECK_EQ(op->num_outputs(), 1);
const te::Tensor& tensor = op.output(0);
// Check op is in op list
- ICHECK(info->IsArg(tensor));
+ ICHECK(info->IsArg(tensor)) << "The operation " << op << " produces tensor
" << tensor
+ << ", but this tensor does not appear as a
function argument. "
+ << "The function accepts arguments " <<
info->arg_list;
// Declare a buffer for any argument tensors without a pre-existing
// buffer declaration recorded in the tensor2buffer binds map
if (info->tensor2buffers.count(tensor) == 0) {
@@ -581,17 +583,16 @@ PrimFunc GenerateAndCompletePrimFunc(const
Array<ObjectRef>& arg_tir_var_list,
const Array<Stmt>& root_stmts,
CreateFuncInfo* info) {
Array<Var> parameters;
Map<Var, Buffer> buffer_map;
- for (const ObjectRef& x : arg_tir_var_list) {
- if (auto n = x.as<te::TensorNode>()) {
- te::Tensor tensor = GetRef<te::Tensor>(n);
+ for (const ObjectRef& arg : arg_tir_var_list) {
+ if (auto opt_tensor = arg.as<te::Tensor>()) {
+ te::Tensor tensor = opt_tensor.value();
Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle()));
parameters.push_back(arg);
auto it = info->tensor2buffers.find(tensor);
ICHECK(it != info->tensor2buffers.end());
buffer_map.Set(arg, it->second);
- } else if (auto n = x.as<tir::VarNode>()) {
- tir::Var var = GetRef<tir::Var>(n);
- parameters.push_back(var);
+ } else if (auto var = arg.as<tir::Var>()) {
+ parameters.push_back(var.value());
}
}
PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index a0ec08f0ab..85842f1578 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -59,15 +59,15 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
-(binary_arith_op,) = tvm.testing.parameters(
- (relax.op.add,),
- (relax.op.divide,),
- (relax.op.floor_divide,),
- (relax.op.multiply,),
- (relax.op.power,),
- (relax.op.subtract,),
- (relax.op.maximum,),
- (relax.op.minimum,),
+(binary_arith_op, tir_arith_op) = tvm.testing.parameters(
+ (relax.op.add, tir.Add),
+ (relax.op.divide, tir.Div),
+ (relax.op.floor_divide, tir.FloorDiv),
+ (relax.op.multiply, tir.Mul),
+ (relax.op.power, tir.pow),
+ (relax.op.subtract, tir.Sub),
+ (relax.op.maximum, tir.Max),
+ (relax.op.minimum, tir.Min),
)
@@ -115,13 +115,47 @@ def test_binary_arith_infer_struct_info(binary_arith_op:
Callable):
)
-(binary_cmp_op,) = tvm.testing.parameters(
- (relax.op.equal,),
- (relax.op.greater,),
- (relax.op.greater_equal,),
- (relax.op.less,),
- (relax.op.less_equal,),
- (relax.op.not_equal,),
+def
test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op:
Callable):
+ bb = relax.BlockBuilder()
+
+ x = relax.Var("x", R.Tensor((2, 3), "float32"))
+ y = relax.Var("y", R.Prim("float32"))
+
+ _check_inference(bb, binary_arith_op(x, y), relax.TensorStructInfo((2, 3),
"float32"))
+
+
+def
test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_op:
Callable):
+ bb = relax.BlockBuilder()
+
+ x = relax.Var("x", R.Prim("float32"))
+ y = relax.Var("y", R.Prim("float32"))
+
+ _check_inference(bb, binary_arith_op(x, y),
relax.PrimStructInfo("float32"))
+
+
[email protected](reason="Not yet implemented")
+def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value(
+ binary_arith_op: Callable, tir_arith_op
+):
+ bb = relax.BlockBuilder()
+
+ tir_x = tir.Var("tir_x", "float32")
+ tir_y = tir.Var("tir_y", "float32")
+
+ x = relax.Var("x", R.Prim(value=tir_x))
+ y = relax.Var("y", R.Prim(value=tir_y))
+
+ _check_inference(bb, binary_arith_op(x, y),
relax.PrimStructInfo(value=tir_x + tir_y))
+ _check_inference(bb, binary_arith_op(y, x),
relax.PrimStructInfo(value=tir_y + tir_x))
+
+
+(binary_cmp_op, tir_cmp_op) = tvm.testing.parameters(
+ (relax.op.equal, tir.EQ),
+ (relax.op.greater, tir.GT),
+ (relax.op.greater_equal, tir.GE),
+ (relax.op.less, tir.LT),
+ (relax.op.less_equal, tir.LE),
+ (relax.op.not_equal, tir.NE),
)
@@ -141,6 +175,38 @@ def test_binary_cmp_infer_struct_info(binary_cmp_op:
Callable):
_check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3),
"bool", vdev0))
+def test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op:
Callable):
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3), "float32"))
+ y = relax.Var("y", R.Prim("float32"))
+ _check_inference(bb, binary_cmp_op(x, y), relax.TensorStructInfo((2, 3),
"bool"))
+ _check_inference(bb, binary_cmp_op(y, x), relax.TensorStructInfo((2, 3),
"bool"))
+
+
+def test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op:
Callable):
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Prim("float32"))
+ y = relax.Var("y", R.Prim("float32"))
+ _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo("bool"))
+ _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo("bool"))
+
+
[email protected](reason="Not yet implemented")
+def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value(
+ binary_cmp_op: Callable, tir_cmp_op
+):
+ bb = relax.BlockBuilder()
+
+ tir_x = tir.Var("tir_x", "float32")
+ tir_y = tir.Var("tir_y", "float32")
+
+ x = relax.Var("x", R.Prim(value=tir_x))
+ y = relax.Var("y", R.Prim(value=tir_y))
+
+ _check_inference(bb, binary_cmp_op(x, y),
relax.PrimStructInfo(value=tir_cmp_op(tir_x, tir_y)))
+ _check_inference(bb, binary_cmp_op(y, x),
relax.PrimStructInfo(value=tir_cmp_op(tir_y, tir_x)))
+
+
def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable):
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
@@ -216,7 +282,7 @@ def
test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3), "float32"))
y = relax.Var("y", R.Tensor((2, 3), "int32"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(binary_arith_op(x, y))
@@ -224,7 +290,7 @@ def
test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callab
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm")))
y = relax.Var("y", R.Tensor((2, 3), "int32", VDevice("cuda")))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(binary_arith_op(x, y))
@@ -245,9 +311,9 @@ def
test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable):
x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
y = relax.Var("y", R.Tensor((2, 3), "float32"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(binary_arith_op(x0, y))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(binary_arith_op(x1, y))
diff --git a/tests/python/relax/test_op_nn_convolution.py
b/tests/python/relax/test_op_nn_convolution.py
index 55e35ee203..588dc9b1b1 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -386,7 +386,7 @@ def test_conv1d_dtype_mismatch():
x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3), "int8"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.nn.conv1d(x, w))
@@ -744,7 +744,7 @@ def test_conv1d_transpose_dtype_mismatch():
x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
w = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.nn.conv1d_transpose(x, w))
@@ -1141,7 +1141,7 @@ def test_conv2d_dtype_mismatch():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.nn.conv2d(x, w))
@@ -1533,7 +1533,7 @@ def test_conv2d_transpose_dtype_mismatch():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.nn.conv2d_transpose(x, w))
diff --git a/tests/python/relax/test_op_search.py
b/tests/python/relax/test_op_search.py
index 21f022d9eb..e67ef442f9 100644
--- a/tests/python/relax/test_op_search.py
+++ b/tests/python/relax/test_op_search.py
@@ -262,9 +262,9 @@ def test_where_infer_struct_info_dtype_mismatch():
x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
y1 = relax.Var("y", R.Tensor((2, 3), "float32"))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.where(cond, x0, y0))
- with pytest.raises(TVMError):
+ with pytest.raises(TypeError):
bb.normalize(relax.op.where(cond, x1, y1))
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py
b/tests/python/relax/test_transform_legalize_ops_binary.py
index d71a248b25..7b94057824 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -17,7 +17,7 @@
import tvm
from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
import tvm.testing
@@ -164,6 +164,44 @@ def test_add_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_add_primvalue():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor([64, 32, 16], "float32"),
+ y: R.Prim("float32"),
+ ):
+ gv = R.add(x, y)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([64, 32, 16], "float32"),
+ y: R.Prim("float32"),
+ ):
+ cls = Expected
+ gv = R.call_tir(cls.add, (x, y), R.Tensor([64, 32, 16],
dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def add(
+ lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+ rhs: T.float32,
+ output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)],
"float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i, j, k in T.grid(*lhs.shape):
+ with T.block("T_add"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ output[vi, vj, vk] = lhs[vi, vj, vk] + rhs
+
+ After = LegalizeOps()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_divide():
# fmt: off
@tvm.script.ir_module
@@ -303,6 +341,44 @@ def test_divide_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_divide_primvalue():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor([64, 32, 16], "float32"),
+ y: R.Prim("float32"),
+ ):
+ gv = R.divide(x, y)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([64, 32, 16], "float32"),
+ y: R.Prim("float32"),
+ ):
+ cls = Expected
+ gv = R.call_tir(cls.divide, (x, y), R.Tensor([64, 32, 16],
dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def divide(
+ lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+ rhs: T.float32,
+ output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)],
"float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i, j, k in T.grid(*lhs.shape):
+ with T.block("T_add"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ output[vi, vj, vk] = lhs[vi, vj, vk] / rhs
+
+ After = LegalizeOps()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_floor_divide():
# fmt: off
@tvm.script.ir_module
@@ -442,6 +518,44 @@ def test_floor_divide_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_floordiv_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.floor_divide(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.floor_divide, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def floor_divide(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_floor_divide"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.floor(lhs[vi, vj, vk] / rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_multiply():
# fmt: off
@tvm.script.ir_module
@@ -519,6 +633,44 @@ def test_multiply_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_multiply_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.multiply(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.multiply, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def multiply(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_multiply"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] * rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_power():
# fmt: off
@tvm.script.ir_module
@@ -599,6 +751,44 @@ def test_power_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_power_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.power(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.power, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def power(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_power"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.pow(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_subtract():
# fmt: off
@tvm.script.ir_module
@@ -676,6 +866,44 @@ def test_subtract_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_subtract_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.subtract(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.subtract, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def subtract(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_subtract"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] - rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
##################### Binary comparison #####################
@@ -818,6 +1046,44 @@ def test_equal_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_equal_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] == rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_greater():
# fmt: off
@tvm.script.ir_module
@@ -957,6 +1223,44 @@ def test_greater_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_greater_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.greater(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.greater, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def greater(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_greater"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = rhs < lhs[vi, vj, vk]
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_greater_equal():
# fmt: off
@tvm.script.ir_module
@@ -1034,6 +1338,44 @@ def test_greater_equal_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_greater_equal_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.greater_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.greater_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def greater_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_greater_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = rhs <= lhs[vi, vj, vk]
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_less():
# fmt: off
@tvm.script.ir_module
@@ -1111,6 +1453,44 @@ def test_less_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_less_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.less(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.less, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def less(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_less"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] < rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_less_equal():
# fmt: off
@tvm.script.ir_module
@@ -1250,6 +1630,44 @@ def test_less_equal_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_less_equal_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.less_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.less_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def less_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_less_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] <= rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_not_equal():
# fmt: off
@tvm.script.ir_module
@@ -1327,6 +1745,44 @@ def test_not_equal_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_not_equal_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.not_equal(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.not_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool"))
+            return gv
+
+        @T.prim_func(private=True)
+        def not_equal(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_not_equal"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = lhs[vi, vj, vk] != rhs
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_maximum():
# fmt: off
@tvm.script.ir_module
@@ -1467,6 +1923,44 @@ def test_maximum_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_max_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.maximum(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.maximum, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def maximum(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_maximum"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.max(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_minimum():
# fmt: off
@tvm.script.ir_module
@@ -1607,5 +2101,43 @@ def test_minimum_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_min_primvalue():  # R.Prim operand lowers to a scalar TIR param
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.minimum(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.minimum, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def minimum(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.block("T_minimum"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.min(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()