This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 81b9c17436 [Unity] Remove attributes of relax.print, assert and unique 
(#14101)
81b9c17436 is described below

commit 81b9c17436c33f259ede9d5cf690b97fa3fda490
Author: Yong Wu <[email protected]>
AuthorDate: Thu Feb 23 07:32:21 2023 -0800

    [Unity] Remove attributes of relax.print, assert and unique (#14101)
    
    Remove the attributes of operators assert, print and unique.
    Use PrimValue as a substitute.
    
    Co-authored-by: Steven S. Lyubomirsky <[email protected]>
    Co-authored-by: Prakalp Srivastava <[email protected]>
---
 include/tvm/relax/attrs/set.h               |  62 ---------------
 include/tvm/relax/op_attr_types.h           |  21 -----
 python/tvm/relax/op/base.py                 |  28 ++++---
 python/tvm/relax/op/builtin/builtin.py      |  16 +++-
 python/tvm/relax/op/op_attrs.py             |   5 --
 python/tvm/relax/op/set.py                  |  33 +++++---
 src/relax/backend/vm/codegen_vm.cc          |  24 +++---
 src/relax/op/op.cc                          |  72 ++++++++++++++++-
 src/relax/op/tensor/set.cc                  |  80 +++++++++++++------
 src/relax/op/tensor/set.h                   |   7 +-
 tests/python/relax/test_relax_operators.py  | 117 +++++++++++++++++++++++++++-
 tests/python/relax/test_tvmscript_parser.py |   4 +-
 12 files changed, 307 insertions(+), 162 deletions(-)

diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h
deleted file mode 100644
index 3fae7646ff..0000000000
--- a/include/tvm/relax/attrs/set.h
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/relax/attrs/set.h
- * \brief Attributes for set operators.
- */
-#ifndef TVM_RELAX_ATTRS_SET_H_
-#define TVM_RELAX_ATTRS_SET_H_
-
-#include <tvm/relax/expr.h>
-
-namespace tvm {
-namespace relax {
-
-/*! \brief Attributes used in unique operator */
-struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
-  bool sorted;
-  bool return_index;
-  bool return_inverse;
-  bool return_counts;
-  Optional<Integer> axis;
-
-  TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") {
-    TVM_ATTR_FIELD(sorted).describe(
-        "Whether to sort the unique elements in ascending order before 
returning as output.");
-    TVM_ATTR_FIELD(return_index)
-        .describe(
-            "Whether to return an additional tensor with indices for where 
elements in the unique "
-            "tensor come from the original input.");
-    TVM_ATTR_FIELD(return_inverse)
-        .describe(
-            "Whether to return an additional tensor with indices for where 
elements in the "
-            "original input ended up in the returned unique list.");
-    TVM_ATTR_FIELD(return_counts)
-        .describe("Whether to return an additional tensor with counts of each 
unique elements");
-    TVM_ATTR_FIELD(axis).describe(
-        "The dimension to apply unique. If it is NullOpt, the unique values of 
the flattened input "
-        "is are returned.");
-  }
-};  // struct UniqueAttrs
-
-}  // namespace relax
-}  // namespace tvm
-
-#endif  // TVM_RELAX_ATTRS_SET_H_
diff --git a/include/tvm/relax/op_attr_types.h 
b/include/tvm/relax/op_attr_types.h
index a34cf251dc..413d3e0499 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -58,27 +58,6 @@ using FCallPacked = String;
  */
 using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const 
Call& call)>;
 
-struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
-  std::string format;
-  TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") {
-    TVM_ATTR_FIELD(format)
-        .describe("Python-style format string to use for displaying the input. 
Ignored if empty.")
-        .set_default("");
-  }
-};
-
-struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
-  std::string format;
-  TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
-    TVM_ATTR_FIELD(format)
-        .describe(
-            "Python-style format string to use for displaying "
-            "an error message if the assert fails. "
-            "Ignored if empty.")
-        .set_default("");
-  }
-};
-
 }  // namespace relax
 }  // namespace tvm
 #endif  // TVM_RELAX_OP_ATTR_TYPES_H_
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index d76b155beb..0b298679c1 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -22,7 +22,7 @@ import tvm
 from tvm.runtime.object import Object
 
 from . import _ffi_api
-from ..expr import Expr, ShapeExpr, Call, ExternFunc
+from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
 from ..expr import Tuple as RxTuple
 from ..struct_info import StructInfo, TensorStructInfo
 from ...ir import PrimExpr
@@ -199,7 +199,7 @@ def render_object(val: tvm.Object) -> str:
     ret: str
         A string representing the value, ideally human-readable
     """
-    if isinstance(val, tvm.runtime.ndarray.NDArray):
+    if isinstance(val, tvm.nd.NDArray):
         return str(val)
     # no pretty-printer by default, so if we don't handle this,
     # then we can't look inside tuples
@@ -211,6 +211,9 @@ def render_object(val: tvm.Object) -> str:
         if val.tag == 0:
             return f"({fields})"
         return f"ADT(tag={val.tag}, fields=[{fields}])"
+    if isinstance(val, tvm.ir.Array):
+        fields = ", ".join([render_object(val[i]) for i in range(len(val))])
+        return f"({fields})"
     return str(val)
 
 
@@ -240,7 +243,7 @@ def relax_print(format_str: str, *format_args: tvm.Object) 
-> None:
         py_print(format_str.format(*val_strs))
 
 
-def print(*values: List[Expr], format: str = "") -> Expr:
+def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr:
     """Print op to print the values
 
     Parameters
@@ -248,14 +251,17 @@ def print(*values: List[Expr], format: str = "") -> Expr:
     values : List[Expr]
         The values to print.
 
-    format_str: str
-        The format string.
+    format: Union[str, Expr]
+        The format string or StringImm.
 
     Returns
     -------
     result : Expr
         A relax Call, which will print the value during runtime.
     """
+    if isinstance(format, str):
+        format = StringImm(format)
+
     return _ffi_api.print(values, format)  # type: ignore # pylint: 
disable=no-member
 
 
@@ -289,7 +295,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, 
*format_args: tvm.Ob
         )
 
     # should be guaranteed by the type system
-    if not isinstance(condition, tvm.runtime.ndarray.NDArray):
+    if not isinstance(condition, tvm.nd.NDArray):
         raise ValueError(f"The condition must be an NDArray, but given a 
{type(condition)}.")
 
     # may happen if the original program had unknown shape or dtype for the 
tensor's type
@@ -313,7 +319,9 @@ def relax_assert_op(condition: tvm.Object, format_str: str, 
*format_args: tvm.Ob
 
 
 def assert_op(
-    condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None, 
format: str = ""
+    condition: Expr,
+    format_args: Optional[Union[Expr, List[Expr]]] = None,
+    format: Union[str, Expr] = "",
 ) -> Expr:
     """
     Create a call to Relax's assert_op operation (`assert` is reserved in 
Python,
@@ -327,8 +335,8 @@ def assert_op(
     format_args: Optional[Union[Expr, List[Expr]]]
         Format arguments for the error message if the condition fails.
 
-    format_str: str
-        The format string for the error message.
+    format: Union[str, Expr]
+        The format string or StringImm for the error message.
 
     Returns
     -------
@@ -339,6 +347,8 @@ def assert_op(
         format_args = []
     if isinstance(format_args, Expr):  # type: ignore
         format_args = [format_args]
+    if isinstance(format, str):
+        format = StringImm(format)
     return _ffi_api.assert_op(condition, format_args, format)  # type: ignore
 
 
diff --git a/python/tvm/relax/op/builtin/builtin.py 
b/python/tvm/relax/op/builtin/builtin.py
index 0afe6a42d0..43bbd461bc 100644
--- a/python/tvm/relax/op/builtin/builtin.py
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -15,13 +15,16 @@
 # specific language governing permissions and limitations
 """The builtin Relax operators."""
 
-from ...expr import Call, Expr
+from typing import Union
+from ...expr import Call, Expr, PrimValue, DataTypeImm
 from ...utils import args_converter
 from . import _ffi_api
 
 
 @args_converter.auto
-def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
+def alloc_tensor(
+    shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int, 
Expr]
+) -> Call:
     """Construct a Call to allocate a tensor with specific shape, dtype, 
runtime_device_index.
 
     Parameters
@@ -29,10 +32,10 @@ def alloc_tensor(shape: Expr, dtype: str, 
runtime_device_index: int) -> Call:
     shape : Expr
         The shape of the tensor to be allocated.
 
-    dtype : str
+    dtype : Union[str, Expr]
         The datatype of the tensor to be allocated.
 
-    runtime_device_index : int
+    runtime_device_index : Union[int, Expr]
         The device index indicating on which device the tensor is to be 
allocated at runtime.
         Index -1 is reserved for the host device.
 
@@ -41,4 +44,9 @@ def alloc_tensor(shape: Expr, dtype: str, 
runtime_device_index: int) -> Call:
     result : Call
         A relax Call, which gets the allocated tensor.
     """
+    if isinstance(dtype, str):
+        dtype = DataTypeImm(dtype)
+    if isinstance(runtime_device_index, int):
+        runtime_device_index = PrimValue(runtime_device_index)
+
     return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index)  # type: 
ignore
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index efad5d98f0..ff89d7c903 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -122,8 +122,3 @@ class LayoutTransformAttrs(Attrs):
 @tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
 class Resize2DAttrs(Attrs):
     """Attributes used in image resize2d operator"""
-
-
-@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
-class UniqueAttrs(Attrs):
-    """Attributes used for the unique operator"""
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index b7ee0f3811..4d106ad6d2 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -16,22 +16,22 @@
 # under the License.
 # pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
 """Set operators."""
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np  # type: ignore
 import tvm
 
 from . import _ffi_api
-from ..expr import Expr
+from ..expr import Expr, PrimValue
 
 
 def unique(
     x: Expr,
-    sorted: bool = True,
-    return_index: bool = False,
-    return_inverse: bool = False,
-    return_counts: bool = False,
-    axis: Optional[int] = None,
+    sorted: Union[bool, Expr] = True,
+    return_index: Union[bool, Expr] = False,
+    return_inverse: Union[bool, Expr] = False,
+    return_counts: Union[bool, Expr] = False,
+    axis: Optional[Union[int, Expr]] = None,
 ) -> Expr:
     """Find the unique elements in a given tensor.
     In addition, it optionally returns
@@ -44,19 +44,19 @@ def unique(
     x : relax.Expr
         The input tensor.
 
-    sorted : bool
+    sorted : Union[bool, Expr]
         Whether to sort the unique elements in ascending order before
         returning as output.
 
-    return_index : bool
+    return_index : Union[bool, Expr]
         Whether to return an additional tensor with indices for where elements 
in
         the unique tensor come from the original input.
 
-    return_inverse : bool
+    return_inverse : Union[bool, Expr]
         Whether to return an additional tensor with indices for where elements 
in
         the original input ended up in the returned unique list.
 
-    return_counts : bool
+    return_counts : Union[bool, Expr]
         Whether to return an additional tensor with counts of each unique 
elements.
 
     axis : Optional
@@ -69,6 +69,16 @@ def unique(
         The created relax call with
     """
 
+    if isinstance(sorted, bool):
+        sorted = PrimValue(sorted)
+    if isinstance(return_index, bool):
+        return_index = PrimValue(return_index)
+    if isinstance(return_inverse, bool):
+        return_inverse = PrimValue(return_inverse)
+    if isinstance(return_counts, bool):
+        return_counts = PrimValue(return_counts)
+    if axis and isinstance(axis, int):
+        axis = PrimValue(axis)
     return _ffi_api.unique(  # type: ignore
         x, sorted, return_index, return_inverse, return_counts, axis
     )
@@ -81,7 +91,6 @@ def numpy_unique(
     return_index: int,
     return_inverse: int,
     return_counts: int,
-    axis: Optional[int],
 ) -> tvm.nd.array:
     """Returns the unique elements of the input tensor.
 
diff --git a/src/relax/backend/vm/codegen_vm.cc 
b/src/relax/backend/vm/codegen_vm.cc
index 1782f1107a..da0ca3a0b5 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -148,7 +148,14 @@ class CodeGenVM : public 
ExprFunctor<Instruction::Arg(const Expr&)> {
     // allocate dst register.
     RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : 
NewRegister();
     if (call->op.as<OpNode>()) {
-      if (call_node->op == call_builtin_with_ctx_op_) {
+      // special case generate for the intrinsics whose attribute fields
+      // cannot be represented by args in the CallNode
+      FCallPacked name = GetPackedFuncName(call);
+      if (!name.empty()) {
+        // If the operator has a registered packed function implementation, 
emit call to that packed
+        // function.
+        EmitPackedFuncCall(call, name, dst_reg);
+      } else if (call_node->op == call_builtin_with_ctx_op_) {
         // TODO(relax-team) migrate most handling of op to
         // directly map to call_builtin_with_ctx before codegen and simplify 
vm codegen.
         EmitCallBuiltinWithCtx(call, dst_reg);
@@ -355,22 +362,9 @@ class CodeGenVM : public 
ExprFunctor<Instruction::Arg(const Expr&)> {
     builder_->EmitCall(func, args, dst_reg);
   }
 
-  // TODO(relax-team) revisit after PrimValue.
-  // Emit the `call_node` attributes as constants and append these constants 
to `args` vector.
-  void AppendAttrsAsConstants(const Call& call_node, 
std::vector<Instruction::Arg>& args) {
-    auto attrs = call_node->attrs;
-    if (!attrs.defined()) return;
-
-    LOG(FATAL) << "Support for attributes of Op " << call_node->op
-               << " has not been implemented yet.";
-    return;
-  }
-
-  // Emits call to packed function `name` with arguments copied over from 
`call_node` args and
-  // attributes.
+  // Emits call to packed function `name` with arguments copied over from 
`call_node` args
   void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, 
RegName dst_reg) {
     std::vector<Instruction::Arg> args = VisitArray(call_node->args);
-    AppendAttrsAsConstants(call_node, args);
     builder_->EmitCall(name, args, dst_reg);
   }
 
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index f478871e21..21d692b6a4 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -141,6 +141,70 @@ Expr MakeCallNullValue() {
 
 TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue);
 
+// print
+
+RELAY_REGISTER_OP("relax.print")
+    .set_num_inputs(-1)
+    .add_argument("vals", "Array<Expr>",
+                  "The first value is Python-style format string to use to 
print. The others "
+                  "are values to print")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.print");
+
+Expr MakePrint(Array<Expr> vals, StringImm format) {
+  Array<Expr> params;
+  params.push_back(format);
+  for (const auto val : vals) {
+    params.push_back(val);
+  }
+  static const Op& op = Op::Get("relax.print");
+  return Call(op, params);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);
+
+// assert_op
+
+// can't actually name it assert or else Python will consider it a syntax error
+
+StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) {
+  // Ensure that the condition argument is a boolean scalar.
+  // Also permitted is a tensor with unknown shape and unknown dtype
+  // (checked dynamically in that case). Returns void.
+  if (call->args.size() < 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Assert must have at least one argument (the 
condition).");
+  }
+  StructInfo arg_struct_info = GetStructInfo(call->args[0]);
+  if (!IsBoolStructInfo(arg_struct_info)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "The argument to assert must be a boolean scalar, but 
received "
+                     << arg_struct_info);
+  }
+  return ReturnVoidStructInfo(call, ctx);
+}
+
+RELAY_REGISTER_OP("relax.assert_op")
+    .set_num_inputs(-1)
+    .add_argument("vals", "Array<Expr>",
+                  "The first value is used as the assertion condition. The 
second value is "
+                  "Python-style format string to use for displaying an error 
message, if the "
+                  "assert fails. The others are used as format arguments if 
there is an error.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferAssertStructInfo)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");
+
+Expr MakeAssertOp(Expr condition, Array<Expr> vals, StringImm format) {
+  static const Op& op = Op::Get("relax.assert_op");
+  Array<Expr> args = {condition};
+  args.push_back(format);
+  for (auto val : vals) {
+    args.push_back(val);
+  }
+  return Call(op, args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp);
+
 // make_closure
 
 RELAY_REGISTER_OP("relax.make_closure")
@@ -213,15 +277,15 @@ StructInfo InferStructInfoAllocateTensor(const Call& 
call, const BlockBuilder& c
 RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
     .set_num_inputs(3)
     .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
-    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
-    .add_argument("runtime_device_index", "int64_t",
+    .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
+    .add_argument("runtime_device_index", "PrimValue",
                   "The device index indicating on which device the tensor is 
to be "
                   "allocated at runtime. Index -1 is reserved for the host 
device.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoAllocateTensor);
 
-Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) 
{
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue 
runtime_device_index) {
   static const Op& op = Op::Get("relax.builtin.alloc_tensor");
-  return Call(op, {shape, DataTypeImm(dtype), 
PrimValue::Int64(runtime_device_index)}, Attrs(), {});
+  return Call(op, {shape, dtype, runtime_device_index}, Attrs(), {});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index 4d5a274e17..8df0813ed2 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -31,34 +31,55 @@ namespace tvm {
 namespace relax {
 
 /* relax.unique */
-TVM_REGISTER_NODE_TYPE(UniqueAttrs);
-
-Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool 
return_counts,
-            Optional<Integer> axis) {
-  ObjectPtr<UniqueAttrs> attrs = make_object<UniqueAttrs>();
-  attrs->sorted = sorted;
-  attrs->return_index = return_index;
-  attrs->return_inverse = return_inverse;
-  attrs->return_counts = return_counts;
-  attrs->axis = std::move(axis);
 
+Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue 
return_inverse,
+            PrimValue return_counts, Optional<PrimValue> axis) {
   static const Op& op = Op::Get("relax.unique");
-  return Call(op, {std::move(x)}, Attrs(attrs), {});
+  Call call;
+  if (!axis) {
+    call = Call(op, {std::move(x), sorted, return_index, return_inverse, 
return_counts});
+  } else {
+    PrimValue pv_axis = axis.value();
+    call = Call(op, {std::move(x), sorted, return_index, return_inverse, 
return_counts, pv_axis});
+  }
+  return call;
 }
 
 TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique);
 
 StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
-  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
-  const auto* attrs = call->attrs.as<UniqueAttrs>();
-  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+  TensorStructInfo data_sinfo = 
Downcast<TensorStructInfo>(call->args[0]->struct_info_);
+  PrimValue axis, return_index, return_inverse, return_counts;
+  if (call->args.size() == 6) {
+    if (auto* prim_value_node = call->args[5].as<PrimValueNode>()) {
+      axis = GetRef<PrimValue>(prim_value_node);
+    }
+  }
+  if (!data_sinfo->IsUnknownNdim() && axis.defined()) {
     // Normalize the axis for sanity check purpose.
-    NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
+    if (const auto* axis_int = axis->value.as<IntImmNode>()) {
+      NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value);
+    }
   }
-
-  int n_int_return = static_cast<int>(attrs->return_index) +
-                     static_cast<int>(attrs->return_inverse) +
-                     static_cast<int>(attrs->return_counts);
+  ICHECK(call->args[2]->IsInstance<PrimValueNode>());
+  ICHECK(call->args[3]->IsInstance<PrimValueNode>());
+  ICHECK(call->args[4]->IsInstance<PrimValueNode>());
+
+  return_index = Downcast<PrimValue>(call->args[2]);
+  return_inverse = Downcast<PrimValue>(call->args[3]);
+  return_counts = Downcast<PrimValue>(call->args[4]);
+
+  auto f_convert_to_int64 = [](const PrimExpr& value) {
+    CHECK(value->IsInstance<IntImmNode>())
+        << value << " expects to be IntImm, but gets " << value->GetTypeKey();
+    const auto* val_node = value.as<IntImmNode>();
+    auto val_imm = GetRef<IntImm>(val_node);
+    return val_imm->value;
+  };
+
+  int64_t n_int_return = f_convert_to_int64(return_index->value) +
+                         f_convert_to_int64(return_inverse->value) +
+                         f_convert_to_int64(return_counts->value);
 
   std::vector<StructInfo> output_sinfo;
   output_sinfo.reserve(1 + n_int_return);
@@ -67,7 +88,7 @@ StructInfo InferStructInfoUnique(const Call& call, const 
BlockBuilder& ctx) {
   if (data_sinfo->ndim == 0) {
     output_sinfo.push_back(
         TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), 
data_sinfo->dtype));
-  } else if (attrs->axis.defined()) {
+  } else if (axis.defined()) {
     output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, 
data_sinfo->ndim));
   } else {
     output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
@@ -93,9 +114,24 @@ StructInfo InferStructInfoUnique(const Call& call, const 
BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.unique")
-    .set_attrs_type<UniqueAttrs>()
-    .set_num_inputs(1)
+    .set_num_inputs(6)
     .add_argument("x", "Tensor", "The input tensor")
+    .add_argument(
+        "sorted", "Tensor",
+        "Whether to sort the unique elements in ascending order before 
returning as output.")
+    .add_argument(
+        "return_index", "Tensor",
+        "Whether to return an additional tensor with indices for where 
elements in the unique "
+        "tensor come from the original input.")
+    .add_argument("return_inverse", "Tensor",
+                  "Whether to return an additional tensor with indices for 
where elements in the "
+                  "original input ended up in the returned unique list.")
+    .add_argument("return_counts", "Tensor",
+                  "Whether to return an additional tensor with counts of each 
unique elements")
+    .add_argument(
+        "axis", "Tensor",
+        "The dimension to apply unique. If it is NullOpt, the unique values of 
the flattened input "
+        "are returned.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoUnique)
     .set_attr<FCallPacked>("FCallPacked", "relax.run.unique");
 
diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h
index 83d2619e4d..a5c7ee85bf 100644
--- a/src/relax/op/tensor/set.h
+++ b/src/relax/op/tensor/set.h
@@ -24,16 +24,13 @@
 #ifndef TVM_RELAX_OP_TENSOR_SET_H_
 #define TVM_RELAX_OP_TENSOR_SET_H_
 
-#include <tvm/relax/attrs/set.h>
-
 #include "../op_common.h"
 
 namespace tvm {
 namespace relax {
 
-Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool 
return_counts,
-            Optional<Integer> axis);
-
+Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue 
return_inverse,
+            PrimValue return_counts, Optional<PrimValue> axis);
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index 7b0b98fea9..c66a5729fd 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -26,13 +26,128 @@ from tvm._ffi.base import TVMError
 from tvm.script import relax as R
 
 
[email protected]_module
+class InputModule:
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), "int64")):
+        y = R.unique(x, sorted=False)
+        y_sorted = R.unique(x)
+        return y, y_sorted
+
+
 def run_cpu(mod, func_name, *input):
     target = tvm.target.Target("llvm")
-    ex = relax.vm.build(mod, target)
+    ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
     return vm[func_name](*input)
 
 
+def test_unique():
+
+    # TODO(prakalp): also add test for compiling and running on cuda device.
+    data_numpy = np.random.randint(0, 16, (16, 16))
+    data = tvm.nd.array(data_numpy)
+    result, result_sorted = run_cpu(InputModule, "foo", data)
+
+    expected_output_sorted, indices = np.unique(data_numpy, return_index=True)
+    expected_output = [data_numpy.flatten()[index] for index in 
sorted(indices, reverse=True)]
+
+    np.testing.assert_array_equal(expected_output_sorted, 
result_sorted.numpy())
+    np.testing.assert_array_equal(expected_output, result.numpy())
+
+
[email protected]_module
+class PrintTest:
+    @R.function
+    def foo(x: R.Tensor((), "int32")):
+        # results have to be bound, but we don't use them
+        # TODO: We should allow calls whose results are not bound for side 
effects;
+        #       it would be easy syntactic sugar to add.
+        p1 = R.print(x)
+        p2 = R.print(x, format="Number: {}")
+        t = (x, x)
+        p3 = R.print(t, format="Tuple: {}")
+        p4 = R.print(x, t)
+        p5 = R.print(x, x, format="Custom print: {} {}")
+        p6 = R.print(x, t, format="Another print: {} {}")
+        return x
+
+
+def test_print():
+    try:
+        stdout = sys.stdout
+        with tempfile.TemporaryFile(mode="w+") as test_out:
+            sys.stdout = test_out
+            run_cpu(PrintTest, "foo", 
tvm.nd.array(np.array(1).astype("int32")))
+            test_out.seek(0)
+            printed_text = str(test_out.read())
+            expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 
1\nAnother print: 1 (1, 1)\n"
+            assert printed_text in expected, ("printed_text is ", printed_text)
+    finally:
+        sys.stdout = stdout
+
+
[email protected]_module
+class AssertOpTest:
+    @R.function
+    def passes(x: R.Tensor((), "int32")):
+        p1 = R.assert_op(relax.const(True))
+        return x
+
+    @R.function
+    def pass_with_args(x: R.Tensor((), "int32")):
+        p1 = R.assert_op(relax.const(True), x, format="You won't see me")
+        return x
+
+    @R.function
+    def simple_fail(x: R.Tensor((), "int32")):
+        p1 = R.assert_op(relax.const(False))
+        return x
+
+    @R.function
+    def fail_with_message(x: R.Tensor((), "int32")):
+        p1 = R.assert_op(relax.const(False), format="I failed...")
+        return x
+
+    @R.function
+    def fail_with_args(x: R.Tensor((), "int32")):
+        # no format
+        p1 = R.assert_op(relax.const(False), [x, x])
+        return x
+
+    @R.function
+    def fail_with_formatted_message(x: R.Tensor((), "int32")):
+        p1 = R.assert_op(relax.const(False), x, format="Number: {}")
+        return x
+
+
+def test_assert_op():
+    def check_assertion_error(func_name, func_arg, expected_message):
+        passed = False
+        try:
+            run_cpu(AssertOpTest, func_name, func_arg)
+            passed = True
+        except TVMError as e:
+            # TVM will print out a TVMError that will contain the
+            # generated error at the bottom of a stack trace
+            assert "AssertionError" in e.args[0]
+            assert expected_message in e.args[0]
+        assert not passed
+
+    run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32")))
+    run_cpu(AssertOpTest, "pass_with_args", 
tvm.nd.array(np.array(2).astype("int32")))
+    check_assertion_error(
+        "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion 
Failed"
+    )
+    check_assertion_error(
+        "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I 
failed..."
+    )
+    check_assertion_error("fail_with_args", 
tvm.nd.array(np.array(5).astype("int32")), "5, 5")
+    check_assertion_error(
+        "fail_with_formatted_message", 
tvm.nd.array(np.array(6).astype("int32")), "Number: 6"
+    )
+
+
 @tvm.script.ir_module
 class ShapeOfTest:
     @R.function
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index b458b290ec..7724c8e761 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -213,8 +213,8 @@ def test_relax_base_op():
         alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), 
"float32", 0))
         shape = bb.emit(relax.op.shape_of(alloc))
         bb.emit_func_output(shape)
-    # todo(yongwww): comment this check because 0 was changed to 
R.prim_value(0) in the printed IR
-    # _check(foo, bb.get()["foo"])
+
+    _check(foo, bb.get()["foo"])
 
 
 def test_symbolic_shape():

Reply via email to