This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 81b9c17436 [Unity] Remove attributes of relax.print, assert and unique (#14101)
81b9c17436 is described below
commit 81b9c17436c33f259ede9d5cf690b97fa3fda490
Author: Yong Wu <[email protected]>
AuthorDate: Thu Feb 23 07:32:21 2023 -0800
[Unity] Remove attributes of relax.print, assert and unique (#14101)
Remove the attributes of operators assert, print and unique.
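Pass the former attribute values as call arguments instead (PrimValue for the boolean/integer options, StringImm for format strings).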
Co-authored-by: Steven S. Lyubomirsky <[email protected]>
Co-authored-by: Prakalp Srivastava <[email protected]>
---
include/tvm/relax/attrs/set.h | 62 ---------------
include/tvm/relax/op_attr_types.h | 21 -----
python/tvm/relax/op/base.py | 28 ++++---
python/tvm/relax/op/builtin/builtin.py | 16 +++-
python/tvm/relax/op/op_attrs.py | 5 --
python/tvm/relax/op/set.py | 33 +++++---
src/relax/backend/vm/codegen_vm.cc | 24 +++---
src/relax/op/op.cc | 72 ++++++++++++++++-
src/relax/op/tensor/set.cc | 80 +++++++++++++------
src/relax/op/tensor/set.h | 7 +-
tests/python/relax/test_relax_operators.py | 117 +++++++++++++++++++++++++++-
tests/python/relax/test_tvmscript_parser.py | 4 +-
12 files changed, 307 insertions(+), 162 deletions(-)
diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h
deleted file mode 100644
index 3fae7646ff..0000000000
--- a/include/tvm/relax/attrs/set.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/relax/attrs/set.h
- * \brief Attributes for set operators.
- */
-#ifndef TVM_RELAX_ATTRS_SET_H_
-#define TVM_RELAX_ATTRS_SET_H_
-
-#include <tvm/relax/expr.h>
-
-namespace tvm {
-namespace relax {
-
-/*! \brief Attributes used in unique operator */
-struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
- bool sorted;
- bool return_index;
- bool return_inverse;
- bool return_counts;
- Optional<Integer> axis;
-
- TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") {
- TVM_ATTR_FIELD(sorted).describe(
- "Whether to sort the unique elements in ascending order before
returning as output.");
- TVM_ATTR_FIELD(return_index)
- .describe(
- "Whether to return an additional tensor with indices for where
elements in the unique "
- "tensor come from the original input.");
- TVM_ATTR_FIELD(return_inverse)
- .describe(
- "Whether to return an additional tensor with indices for where
elements in the "
- "original input ended up in the returned unique list.");
- TVM_ATTR_FIELD(return_counts)
- .describe("Whether to return an additional tensor with counts of each
unique elements");
- TVM_ATTR_FIELD(axis).describe(
- "The dimension to apply unique. If it is NullOpt, the unique values of
the flattened input "
- "is are returned.");
- }
-}; // struct UniqueAttrs
-
-} // namespace relax
-} // namespace tvm
-
-#endif // TVM_RELAX_ATTRS_SET_H_
diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index a34cf251dc..413d3e0499 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -58,27 +58,6 @@ using FCallPacked = String;
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const
Call& call)>;
-struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
- std::string format;
- TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") {
- TVM_ATTR_FIELD(format)
- .describe("Python-style format string to use for displaying the input.
Ignored if empty.")
- .set_default("");
- }
-};
-
-struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
- std::string format;
- TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
- TVM_ATTR_FIELD(format)
- .describe(
- "Python-style format string to use for displaying "
- "an error message if the assert fails. "
- "Ignored if empty.")
- .set_default("");
- }
-};
-
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index d76b155beb..0b298679c1 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -22,7 +22,7 @@ import tvm
from tvm.runtime.object import Object
from . import _ffi_api
-from ..expr import Expr, ShapeExpr, Call, ExternFunc
+from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
@@ -199,7 +199,7 @@ def render_object(val: tvm.Object) -> str:
ret: str
A string representing the value, ideally human-readable
"""
- if isinstance(val, tvm.runtime.ndarray.NDArray):
+ if isinstance(val, tvm.nd.NDArray):
return str(val)
# no pretty-printer by default, so if we don't handle this,
# then we can't look inside tuples
@@ -211,6 +211,9 @@ def render_object(val: tvm.Object) -> str:
if val.tag == 0:
return f"({fields})"
return f"ADT(tag={val.tag}, fields=[{fields}])"
+ if isinstance(val, tvm.ir.Array):
+ fields = ", ".join([render_object(val[i]) for i in range(len(val))])
+ return f"({fields})"
return str(val)
@@ -240,7 +243,7 @@ def relax_print(format_str: str, *format_args: tvm.Object) -> None:
py_print(format_str.format(*val_strs))
-def print(*values: List[Expr], format: str = "") -> Expr:
+def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr:
"""Print op to print the values
Parameters
@@ -248,14 +251,17 @@ def print(*values: List[Expr], format: str = "") -> Expr:
values : List[Expr]
The values to print.
- format_str: str
- The format string.
+ format: Union[str, Expr]
+ The format string or StringImm.
Returns
-------
result : Expr
A relax Call, which will print the value during runtime.
"""
+ if isinstance(format, str):
+ format = StringImm(format)
+
return _ffi_api.print(values, format) # type: ignore # pylint:
disable=no-member
@@ -289,7 +295,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob
)
# should be guaranteed by the type system
- if not isinstance(condition, tvm.runtime.ndarray.NDArray):
+ if not isinstance(condition, tvm.nd.NDArray):
raise ValueError(f"The condition must be an NDArray, but given a
{type(condition)}.")
# may happen if the original program had unknown shape or dtype for the
tensor's type
@@ -313,7 +319,9 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob
def assert_op(
- condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None,
format: str = ""
+ condition: Expr,
+ format_args: Optional[Union[Expr, List[Expr]]] = None,
+ format: Union[str, Expr] = "",
) -> Expr:
"""
Create a call to Relax's assert_op operation (`assert` is reserved in
Python,
@@ -327,8 +335,8 @@ def assert_op(
format_args: Optional[Union[Expr, List[Expr]]]
Format arguments for the error message if the condition fails.
- format_str: str
- The format string for the error message.
+ format: Union[str, Expr]
+ The format string or StringImm for the error message.
Returns
-------
@@ -339,6 +347,8 @@ def assert_op(
format_args = []
if isinstance(format_args, Expr): # type: ignore
format_args = [format_args]
+ if isinstance(format, str):
+ format = StringImm(format)
return _ffi_api.assert_op(condition, format_args, format) # type: ignore
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
index 0afe6a42d0..43bbd461bc 100644
--- a/python/tvm/relax/op/builtin/builtin.py
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -15,13 +15,16 @@
# specific language governing permissions and limitations
"""The builtin Relax operators."""
-from ...expr import Call, Expr
+from typing import Union
+from ...expr import Call, Expr, PrimValue, DataTypeImm
from ...utils import args_converter
from . import _ffi_api
@args_converter.auto
-def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
+def alloc_tensor(
+ shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int,
Expr]
+) -> Call:
"""Construct a Call to allocate a tensor with specific shape, dtype,
runtime_device_index.
Parameters
@@ -29,10 +32,10 @@ def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
shape : Expr
The shape of the tensor to be allocated.
- dtype : str
+ dtype : Union[str, Expr]
The datatype of the tensor to be allocated.
- runtime_device_index : int
+ runtime_device_index : Union[int, Expr]
The device index indicating on which device the tensor is to be
allocated at runtime.
Index -1 is reserved for the host device.
@@ -41,4 +44,9 @@ def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
result : Call
A relax Call, which gets the allocated tensor.
"""
+ if isinstance(dtype, str):
+ dtype = DataTypeImm(dtype)
+ if isinstance(runtime_device_index, int):
+ runtime_device_index = PrimValue(runtime_device_index)
+
return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type:
ignore
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index efad5d98f0..ff89d7c903 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -122,8 +122,3 @@ class LayoutTransformAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
class Resize2DAttrs(Attrs):
"""Attributes used in image resize2d operator"""
-
-
-@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
-class UniqueAttrs(Attrs):
- """Attributes used for the unique operator"""
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index b7ee0f3811..4d106ad6d2 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -16,22 +16,22 @@
# under the License.
# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
"""Set operators."""
-from typing import Optional
+from typing import Optional, Union
import numpy as np # type: ignore
import tvm
from . import _ffi_api
-from ..expr import Expr
+from ..expr import Expr, PrimValue
def unique(
x: Expr,
- sorted: bool = True,
- return_index: bool = False,
- return_inverse: bool = False,
- return_counts: bool = False,
- axis: Optional[int] = None,
+ sorted: Union[bool, Expr] = True,
+ return_index: Union[bool, Expr] = False,
+ return_inverse: Union[bool, Expr] = False,
+ return_counts: Union[bool, Expr] = False,
+ axis: Optional[Union[int, Expr]] = None,
) -> Expr:
"""Find the unique elements in a given tensor.
In addition, it optionally returns
@@ -44,19 +44,19 @@ def unique(
x : relax.Expr
The input tensor.
- sorted : bool
+ sorted : Union[bool, Expr]
Whether to sort the unique elements in ascending order before
returning as output.
- return_index : bool
+ return_index : Union[bool, Expr]
Whether to return an additional tensor with indices for where elements
in
the unique tensor come from the original input.
- return_inverse : bool
+ return_inverse : Union[bool, Expr]
Whether to return an additional tensor with indices for where elements
in
the original input ended up in the returned unique list.
- return_counts : bool
+ return_counts : Union[bool, Expr]
Whether to return an additional tensor with counts of each unique
elements.
axis : Optional
@@ -69,6 +69,16 @@ def unique(
The created relax call with
"""
+ if isinstance(sorted, bool):
+ sorted = PrimValue(sorted)
+ if isinstance(return_index, bool):
+ return_index = PrimValue(return_index)
+ if isinstance(return_inverse, bool):
+ return_inverse = PrimValue(return_inverse)
+ if isinstance(return_counts, bool):
+ return_counts = PrimValue(return_counts)
+ if axis and isinstance(axis, int):
+ axis = PrimValue(axis)
return _ffi_api.unique( # type: ignore
x, sorted, return_index, return_inverse, return_counts, axis
)
@@ -81,7 +91,6 @@ def numpy_unique(
return_index: int,
return_inverse: int,
return_counts: int,
- axis: Optional[int],
) -> tvm.nd.array:
"""Returns the unique elements of the input tensor.
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 1782f1107a..da0ca3a0b5 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -148,7 +148,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
// allocate dst register.
RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister :
NewRegister();
if (call->op.as<OpNode>()) {
- if (call_node->op == call_builtin_with_ctx_op_) {
+ // special case generate for the intrinsics whose attribute fields
+ // cannot be represented by args in the CallNode
+ FCallPacked name = GetPackedFuncName(call);
+ if (!name.empty()) {
+ // If the operator has a registered packed function implementation,
emit call to that packed
+ // function.
+ EmitPackedFuncCall(call, name, dst_reg);
+ } else if (call_node->op == call_builtin_with_ctx_op_) {
// TODO(relax-team) migrate most handling of op to
// directly map to call_builtin_with_ctx before codegen and simplify
vm codegen.
EmitCallBuiltinWithCtx(call, dst_reg);
@@ -355,22 +362,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
builder_->EmitCall(func, args, dst_reg);
}
- // TODO(relax-team) revisit after PrimValue.
- // Emit the `call_node` attributes as constants and append these constants
to `args` vector.
- void AppendAttrsAsConstants(const Call& call_node,
std::vector<Instruction::Arg>& args) {
- auto attrs = call_node->attrs;
- if (!attrs.defined()) return;
-
- LOG(FATAL) << "Support for attributes of Op " << call_node->op
- << " has not been implemented yet.";
- return;
- }
-
- // Emits call to packed function `name` with arguments copied over from
`call_node` args and
- // attributes.
+ // Emits call to packed function `name` with arguments copied over from
`call_node` args
void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name,
RegName dst_reg) {
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
- AppendAttrsAsConstants(call_node, args);
builder_->EmitCall(name, args, dst_reg);
}
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index f478871e21..21d692b6a4 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -141,6 +141,70 @@ Expr MakeCallNullValue() {
TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue);
+// print
+
+RELAY_REGISTER_OP("relax.print")
+ .set_num_inputs(-1)
+ .add_argument("vals", "Array<Expr>",
+ "The first value is Python-style format string to use to
print. The others "
+ "are values to print")
+ .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
+ .set_attr<FCallPacked>("FCallPacked", "relax.run.print");
+
+Expr MakePrint(Array<Expr> vals, StringImm format) {
+ Array<Expr> params;
+ params.push_back(format);
+ for (const auto val : vals) {
+ params.push_back(val);
+ }
+ static const Op& op = Op::Get("relax.print");
+ return Call(op, params);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);
+
+// assert_op
+
+// can't actually name it assert or else Python will consider it a syntax error
+
+StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) {
+ // Ensure that the condition argument is a boolean scalar.
+ // Also permitted is a tensor with unknown shape and unknown dtype
+ // (checked dynamically in that case). Returns void.
+ if (call->args.size() < 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Assert must have at least one argument (the
condition).");
+ }
+ StructInfo arg_struct_info = GetStructInfo(call->args[0]);
+ if (!IsBoolStructInfo(arg_struct_info)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "The argument to assert must be a boolean scalar, but
received "
+ << arg_struct_info);
+ }
+ return ReturnVoidStructInfo(call, ctx);
+}
+
+RELAY_REGISTER_OP("relax.assert_op")
+ .set_num_inputs(-1)
+ .add_argument("vals", "Array<Expr>",
+ "The first value is used as the assertion condition. The
second value is "
+ "Python-style format string to use for displaying an error
message, if the "
+ "assert fails. The others are used as format arguments if
there is an error.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferAssertStructInfo)
+ .set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");
+
+Expr MakeAssertOp(Expr condition, Array<Expr> vals, StringImm format) {
+ static const Op& op = Op::Get("relax.assert_op");
+ Array<Expr> args = {condition};
+ args.push_back(format);
+ for (auto val : vals) {
+ args.push_back(val);
+ }
+ return Call(op, args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp);
+
// make_closure
RELAY_REGISTER_OP("relax.make_closure")
@@ -213,15 +277,15 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c
RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
.set_num_inputs(3)
.add_argument("shape", "Expr", "The shape of the tensor to allocate.")
- .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
- .add_argument("runtime_device_index", "int64_t",
+ .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to
allocate.")
+ .add_argument("runtime_device_index", "PrimValue",
"The device index indicating on which device the tensor is
to be "
"allocated at runtime. Index -1 is reserved for the host
device.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAllocateTensor);
-Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index)
{
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue
runtime_device_index) {
static const Op& op = Op::Get("relax.builtin.alloc_tensor");
- return Call(op, {shape, DataTypeImm(dtype),
PrimValue::Int64(runtime_device_index)}, Attrs(), {});
+ return Call(op, {shape, dtype, runtime_device_index}, Attrs(), {});
}
TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index 4d5a274e17..8df0813ed2 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -31,34 +31,55 @@ namespace tvm {
namespace relax {
/* relax.unique */
-TVM_REGISTER_NODE_TYPE(UniqueAttrs);
-
-Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool
return_counts,
- Optional<Integer> axis) {
- ObjectPtr<UniqueAttrs> attrs = make_object<UniqueAttrs>();
- attrs->sorted = sorted;
- attrs->return_index = return_index;
- attrs->return_inverse = return_inverse;
- attrs->return_counts = return_counts;
- attrs->axis = std::move(axis);
+Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue
return_inverse,
+ PrimValue return_counts, Optional<PrimValue> axis) {
static const Op& op = Op::Get("relax.unique");
- return Call(op, {std::move(x)}, Attrs(attrs), {});
+ Call call;
+ if (!axis) {
+ call = Call(op, {std::move(x), sorted, return_index, return_inverse,
return_counts});
+ } else {
+ PrimValue pv_axis = axis.value();
+ call = Call(op, {std::move(x), sorted, return_index, return_inverse,
return_counts, pv_axis});
+ }
+ return call;
}
TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique);
StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
- TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
- const auto* attrs = call->attrs.as<UniqueAttrs>();
- if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+ TensorStructInfo data_sinfo =
Downcast<TensorStructInfo>(call->args[0]->struct_info_);
+ PrimValue axis, return_index, return_inverse, return_counts;
+ if (call->args.size() == 6) {
+ if (auto* prim_value_node = call->args[5].as<PrimValueNode>()) {
+ axis = GetRef<PrimValue>(prim_value_node);
+ }
+ }
+ if (!data_sinfo->IsUnknownNdim() && axis.defined()) {
// Normalize the axis for sanity check purpose.
- NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
+ if (const auto* axis_int = axis->value.as<IntImmNode>()) {
+ NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value);
+ }
}
-
- int n_int_return = static_cast<int>(attrs->return_index) +
- static_cast<int>(attrs->return_inverse) +
- static_cast<int>(attrs->return_counts);
+ ICHECK(call->args[2]->IsInstance<PrimValueNode>());
+ ICHECK(call->args[3]->IsInstance<PrimValueNode>());
+ ICHECK(call->args[4]->IsInstance<PrimValueNode>());
+
+ return_index = Downcast<PrimValue>(call->args[2]);
+ return_inverse = Downcast<PrimValue>(call->args[3]);
+ return_counts = Downcast<PrimValue>(call->args[4]);
+
+ auto f_convert_to_int64 = [](const PrimExpr& value) {
+ CHECK(value->IsInstance<IntImmNode>())
+ << value << " expects to be IntImm, but gets " << value->GetTypeKey();
+ const auto* val_node = value.as<IntImmNode>();
+ auto val_imm = GetRef<IntImm>(val_node);
+ return val_imm->value;
+ };
+
+ int64_t n_int_return = f_convert_to_int64(return_index->value) +
+ f_convert_to_int64(return_inverse->value) +
+ f_convert_to_int64(return_counts->value);
std::vector<StructInfo> output_sinfo;
output_sinfo.reserve(1 + n_int_return);
@@ -67,7 +88,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
if (data_sinfo->ndim == 0) {
output_sinfo.push_back(
TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
data_sinfo->dtype));
- } else if (attrs->axis.defined()) {
+ } else if (axis.defined()) {
output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype,
data_sinfo->ndim));
} else {
output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
@@ -93,9 +114,24 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
}
TVM_REGISTER_OP("relax.unique")
- .set_attrs_type<UniqueAttrs>()
- .set_num_inputs(1)
+ .set_num_inputs(6)
.add_argument("x", "Tensor", "The input tensor")
+ .add_argument(
+ "sorted", "Tensor",
+ "Whether to sort the unique elements in ascending order before
returning as output.")
+ .add_argument(
+ "return_index", "Tensor",
+ "Whether to return an additional tensor with indices for where
elements in the unique "
+ "tensor come from the original input.")
+ .add_argument("return_inverse", "Tensor",
+ "Whether to return an additional tensor with indices for
where elements in the "
+ "original input ended up in the returned unique list.")
+ .add_argument("return_counts", "Tensor",
+ "Whether to return an additional tensor with counts of each
unique elements")
+ .add_argument(
+ "axis", "Tensor",
+ "The dimension to apply unique. If it is NullOpt, the unique values of
the flattened input "
+ "are returned.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoUnique)
.set_attr<FCallPacked>("FCallPacked", "relax.run.unique");
diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h
index 83d2619e4d..a5c7ee85bf 100644
--- a/src/relax/op/tensor/set.h
+++ b/src/relax/op/tensor/set.h
@@ -24,16 +24,13 @@
#ifndef TVM_RELAX_OP_TENSOR_SET_H_
#define TVM_RELAX_OP_TENSOR_SET_H_
-#include <tvm/relax/attrs/set.h>
-
#include "../op_common.h"
namespace tvm {
namespace relax {
-Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool
return_counts,
- Optional<Integer> axis);
-
+Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue
return_inverse,
+ PrimValue return_counts, Optional<PrimValue> axis);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 7b0b98fea9..c66a5729fd 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -26,13 +26,128 @@ from tvm._ffi.base import TVMError
from tvm.script import relax as R
[email protected]_module
+class InputModule:
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "int64")):
+ y = R.unique(x, sorted=False)
+ y_sorted = R.unique(x)
+ return y, y_sorted
+
+
def run_cpu(mod, func_name, *input):
target = tvm.target.Target("llvm")
- ex = relax.vm.build(mod, target)
+ ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
return vm[func_name](*input)
+def test_unique():
+
+ # TODO(prakalp): also add test for compiling and running on cuda device.
+ data_numpy = np.random.randint(0, 16, (16, 16))
+ data = tvm.nd.array(data_numpy)
+ result, result_sorted = run_cpu(InputModule, "foo", data)
+
+ expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
+ expected_output = [data_numpy.flatten()[index] for index in
sorted(indices, reverse=True)]
+
+ np.testing.assert_array_equal(expected_output_sorted,
result_sorted.numpy())
+ np.testing.assert_array_equal(expected_output, result.numpy())
+
+
[email protected]_module
+class PrintTest:
+ @R.function
+ def foo(x: R.Tensor((), "int32")):
+ # results have to be bound, but we don't use them
+ # TODO: We should allow calls whose results are not bound for side
effects;
+ # it would be easy syntactic sugar to add.
+ p1 = R.print(x)
+ p2 = R.print(x, format="Number: {}")
+ t = (x, x)
+ p3 = R.print(t, format="Tuple: {}")
+ p4 = R.print(x, t)
+ p5 = R.print(x, x, format="Custom print: {} {}")
+ p6 = R.print(x, t, format="Another print: {} {}")
+ return x
+
+
+def test_print():
+ try:
+ stdout = sys.stdout
+ with tempfile.TemporaryFile(mode="w+") as test_out:
+ sys.stdout = test_out
+ run_cpu(PrintTest, "foo",
tvm.nd.array(np.array(1).astype("int32")))
+ test_out.seek(0)
+ printed_text = str(test_out.read())
+ expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1
1\nAnother print: 1 (1, 1)\n"
+ assert printed_text in expected, ("printed_text is ", printed_text)
+ finally:
+ sys.stdout = stdout
+
+
[email protected]_module
+class AssertOpTest:
+ @R.function
+ def passes(x: R.Tensor((), "int32")):
+ p1 = R.assert_op(relax.const(True))
+ return x
+
+ @R.function
+ def pass_with_args(x: R.Tensor((), "int32")):
+ p1 = R.assert_op(relax.const(True), x, format="You won't see me")
+ return x
+
+ @R.function
+ def simple_fail(x: R.Tensor((), "int32")):
+ p1 = R.assert_op(relax.const(False))
+ return x
+
+ @R.function
+ def fail_with_message(x: R.Tensor((), "int32")):
+ p1 = R.assert_op(relax.const(False), format="I failed...")
+ return x
+
+ @R.function
+ def fail_with_args(x: R.Tensor((), "int32")):
+ # no format
+ p1 = R.assert_op(relax.const(False), [x, x])
+ return x
+
+ @R.function
+ def fail_with_formatted_message(x: R.Tensor((), "int32")):
+ p1 = R.assert_op(relax.const(False), x, format="Number: {}")
+ return x
+
+
+def test_assert_op():
+ def check_assertion_error(func_name, func_arg, expected_message):
+ passed = False
+ try:
+ run_cpu(AssertOpTest, func_name, func_arg)
+ passed = True
+ except TVMError as e:
+ # TVM will print out a TVMError that will contain the
+ # generated error at the bottom of a stack trace
+ assert "AssertionError" in e.args[0]
+ assert expected_message in e.args[0]
+ assert not passed
+
+ run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32")))
+ run_cpu(AssertOpTest, "pass_with_args",
tvm.nd.array(np.array(2).astype("int32")))
+ check_assertion_error(
+ "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion
Failed"
+ )
+ check_assertion_error(
+ "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I
failed..."
+ )
+ check_assertion_error("fail_with_args",
tvm.nd.array(np.array(5).astype("int32")), "5, 5")
+ check_assertion_error(
+ "fail_with_formatted_message",
tvm.nd.array(np.array(6).astype("int32")), "Number: 6"
+ )
+
+
@tvm.script.ir_module
class ShapeOfTest:
@R.function
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index b458b290ec..7724c8e761 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -213,8 +213,8 @@ def test_relax_base_op():
alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)),
"float32", 0))
shape = bb.emit(relax.op.shape_of(alloc))
bb.emit_func_output(shape)
- # todo(yongwww): comment this check because 0 was changed to
R.prim_value(0) in the printed IR
- # _check(foo, bb.get()["foo"])
+
+ _check(foo, bb.get()["foo"])
def test_symbolic_shape():