This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 77a319e9262f711bb97429c1da881a9a6e69069b
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 14:55:30 2023 -0500

    [Unity] Relax op: set (#13990)
    
    This PR is about the high-level tensor computation operators in Relax.
    
    This PR includes the set operators.
    
    Co-authored-by: Prakalp Srivastava <[email protected]>
---
 include/tvm/relax/attrs/set.h                      |  62 ++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/op/set.py                         | 101 +++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/tensor/set.cc                         | 103 +++
 src/relax/op/tensor/set.h                          |  40 +
 tests/python/relax/test_op_set.py                  | 862 +++++++++++++++++++++
 tests/python/relax/test_tvmscript_parser_op_set.py |  68 ++
 9 files changed, 1244 insertions(+)

diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h
new file mode 100644
index 0000000000..3fae7646ff
--- /dev/null
+++ b/include/tvm/relax/attrs/set.h
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/set.h
+ * \brief Attributes for set operators.
+ */
+#ifndef TVM_RELAX_ATTRS_SET_H_
+#define TVM_RELAX_ATTRS_SET_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in the unique operator (relax.unique). */
+struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
+  bool sorted;
+  bool return_index;
+  bool return_inverse;
+  bool return_counts;
+  Optional<Integer> axis;  // NullOpt means "operate on the flattened input"
+
+  TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") {
+    TVM_ATTR_FIELD(sorted).describe(
+        "Whether to sort the unique elements in ascending order before returning as output.");
+    TVM_ATTR_FIELD(return_index)
+        .describe(
+            "Whether to return an additional tensor with indices for where elements in the unique "
+            "tensor come from the original input.");
+    TVM_ATTR_FIELD(return_inverse)
+        .describe(
+            "Whether to return an additional tensor with indices for where elements in the "
+            "original input ended up in the returned unique list.");
+    TVM_ATTR_FIELD(return_counts)
+        .describe("Whether to return an additional tensor with counts of each unique element.");
+    TVM_ATTR_FIELD(axis).describe(
+        "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input "
+        "are returned.");
+  }
+};  // struct UniqueAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_SET_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index f3ab9085b8..da29c3715d 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -24,5 +24,6 @@ from .datatype import *
 from .index import *
 from .manipulate import *
 from .op_attrs import *
+from .set import *
 from . import builtin
 from . import memory
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index cb33363944..47c3b28798 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -32,3 +32,8 @@ class TakeAttrs(Attrs):
 @tvm._ffi.register_object("relax.attrs.StridedSliceAttrs")
 class StridedSliceAttrs(Attrs):
     """Attributes used in strided_slice operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
+class UniqueAttrs(Attrs):
+    """Attributes used for the unique operator (FFI node type "relax.attrs.UniqueAttrs")."""
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
new file mode 100644
index 0000000000..b7ee0f3811
--- /dev/null
+++ b/python/tvm/relax/op/set.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
+"""Set operators."""
+from typing import Optional
+
+import numpy as np  # type: ignore
+import tvm
+
+from . import _ffi_api
+from ..expr import Expr
+
+
+def unique(
+    x: Expr,
+    sorted: bool = True,
+    return_index: bool = False,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    axis: Optional[int] = None,
+) -> Expr:
+    """Find the unique elements in a given tensor.
+    In addition, it optionally returns
+    - the indices of the input tensor that give the unique values;
+    - the indices of the unique tensor that reconstruct the input tensor;
+    - the number of times each unique value comes up in the input tensor.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input tensor.
+
+    sorted : bool
+        Whether to sort the unique elements in ascending order before
+        returning as output.
+
+    return_index : bool
+        Whether to return an additional tensor with indices for where elements 
in
+        the unique tensor come from the original input.
+
+    return_inverse : bool
+        Whether to return an additional tensor with indices for where elements 
in
+        the original input ended up in the returned unique list.
+
+    return_counts : bool
+        Whether to return an additional tensor with counts of each unique 
elements.
+
+    axis : Optional
+        The dimension to apply unique.
+        If not specified, the unique values of the flattened input are 
returned.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The created relax call with
+    """
+
+    return _ffi_api.unique(  # type: ignore
+        x, sorted, return_index, return_inverse, return_counts, axis
+    )
+
+
[email protected]_func("relax.run.unique")
+def numpy_unique(
+    x: tvm.nd.array,
+    sorted: int,
+    return_index: int,
+    return_inverse: int,
+    return_counts: int,
+    axis: Optional[int],
+) -> tvm.nd.array:
+    """Returns the unique elements of the input tensor.
+
+    Uses numpy.unique to compute unique elements.
+    """
+    import builtins
+
+    # TODO(prakalp): add support for returning a tuple when return_inverse or 
return_counts is True
+    if bool(return_index) or bool(return_inverse) or bool(return_counts):
+        raise NotImplementedError("missing support return_inverse or 
return_counts set to true")
+    x_numpy = x.numpy()
+    # TODO(prakalp): use torch.unique instead of numpy when torch is installed 
in ci.
+    output_sorted_numpy, indices = np.unique(x_numpy, return_index=True)
+    if sorted:
+        return tvm.nd.array(output_sorted_numpy)
+    output_numpy = [x_numpy.flatten()[index] for index in 
builtins.sorted(indices, reverse=True)]
+    return tvm.nd.array(output_numpy)
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index aaee0f4e2f..537adec615 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,7 @@ from tvm.relax.op import (
     shape_of,
     strided_slice,
     take,
+    unique,
 )
 from tvm.relax.struct_info import StructInfo
 from tvm.relax.utils import args_converter
@@ -434,4 +435,5 @@ __all__ = [
     "strided_slice",
     "take",
     "tuple",
+    "unique",
 ]
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
new file mode 100644
index 0000000000..4d5a274e17
--- /dev/null
+++ b/src/relax/op/tensor/set.cc
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file set.cc
+ * \brief Relax set operators.
+ */
+
+#include "set.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.unique */
+TVM_REGISTER_NODE_TYPE(UniqueAttrs);
+
+Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts,
+            Optional<Integer> axis) {
+  // Function-local static: the registered op handle is looked up only on the first call.
+  static const Op& unique_op = Op::Get("relax.unique");
+  auto unique_attrs = make_object<UniqueAttrs>();
+  unique_attrs->sorted = sorted;
+  unique_attrs->return_index = return_index;
+  unique_attrs->return_inverse = return_inverse;
+  unique_attrs->return_counts = return_counts;
+  unique_attrs->axis = std::move(axis);  // NullOpt selects the flattened-input behavior
+  return Call(unique_op, {std::move(x)}, Attrs(unique_attrs), {});  // single tensor input
+}
+
+TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique);
+
+StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<UniqueAttrs>();
+  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+    // Normalize the axis only as a sanity check; the normalized value is not used.
+    NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
+  }
+
+  int n_int_return = static_cast<int>(attrs->return_index) +  // count of extra int64 outputs
+                     static_cast<int>(attrs->return_inverse) +
+                     static_cast<int>(attrs->return_counts);
+
+  std::vector<StructInfo> output_sinfo;
+  output_sinfo.reserve(1 + n_int_return);
+
+  // Unique values: shape (1,) for a 0-d input, input ndim if axis is given, otherwise 1-D.
+  if (data_sinfo->ndim == 0) {
+    output_sinfo.push_back(
+        TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), 
data_sinfo->dtype));
+  } else if (attrs->axis.defined()) {
+    output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, 
data_sinfo->ndim));
+  } else {
+    output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
+  }
+
+  // Index, inverse and counts: each int64, shape (1,) for a 0-d input, otherwise 1-D.
+  TensorStructInfo int_return{nullptr};  // assigned in both branches below
+  if (data_sinfo->ndim == 0) {
+    int_return =
+        TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), 
DataType::Int(64));
+  } else {
+    int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
+  }
+  for (int i = 0; i < n_int_return; ++i) {
+    output_sinfo.push_back(int_return);
+  }
+
+  if (output_sinfo.size() == 1) {  // no extra outputs: return the bare tensor, not a tuple
+    return output_sinfo[0];
+  } else {
+    return TupleStructInfo(output_sinfo);
+  }
+}
+
+TVM_REGISTER_OP("relax.unique")  // calls are built by unique() in this file
+    .set_attrs_type<UniqueAttrs>()
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoUnique)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.unique");  // see numpy_unique in set.py
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h
new file mode 100644
index 0000000000..83d2619e4d
--- /dev/null
+++ b/src/relax/op/tensor/set.h
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file set.h
+ * \brief The functions to make Relax set operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_SET_H_
+#define TVM_RELAX_OP_TENSOR_SET_H_
+
+#include <tvm/relax/attrs/set.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool 
return_counts,
+            Optional<Integer> axis);  // builds a relax.unique Call (defined in set.cc)
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_SET_H_
diff --git a/tests/python/relax/test_op_set.py 
b/tests/python/relax/test_op_set.py
new file mode 100644
index 0000000000..755d5e8f87
--- /dev/null
+++ b/tests/python/relax/test_op_set.py
@@ -0,0 +1,862 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    assert relax.op.unique(x).op == Op.get("relax.unique")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_unique_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=False, return_inverse=False, return_counts=True, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=False, return_inverse=True, return_counts=False, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=True, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=False, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=False, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=-2),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, sorted=True, return_index=True, return_inverse=True, 
return_counts=True, axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, sorted=True, return_index=True, return_inverse=True, 
return_counts=True, axis=1
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x1, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x1, return_index=False, return_inverse=True, return_counts=False, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=True, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=False, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x2, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x2, return_index=True, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo(dtype="float32"), 
relax.TensorStructInfo(dtype="int64", ndim=1)]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=True, 
return_counts=False, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=True, 
return_counts=False, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x3, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x3, return_index=False, return_inverse=False, return_counts=True, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=False, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=False, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=False, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x3, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_unique_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=False, return_inverse=False, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=False, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=False, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=False, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_unique_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4)))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x0, return_index=False, return_inverse=False, return_counts=True, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=False, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x1, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo(dtype="float32", ndim=1),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=False, 
return_counts=False, axis=1),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x1, return_index=False, return_inverse=False, return_counts=True, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=False, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo(dtype="float32"), 
relax.TensorStructInfo(dtype="int64", ndim=1)]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=False, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=True, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_unique_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32"))
+
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float16", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="int8", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x2, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="int32", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+                relax.TensorStructInfo(dtype="int64", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_unique_infer_struct_info_input_zero_rank():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(()))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0))
+    x0 = relax.Var("x", R.Tensor((), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=0))
+    x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.unique(x0, return_index=True, return_inverse=True, 
return_counts=True, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((1,), "float32"),
+                relax.TensorStructInfo((1,), "int64"),
+                relax.TensorStructInfo((1,), "int64"),
+                relax.TensorStructInfo((1,), "int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(x1, return_index=True, return_inverse=True, 
return_counts=False, axis=None),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((1,), "float32"),
+                relax.TensorStructInfo((1,), "int64"),
+                relax.TensorStructInfo((1,), "int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x2, return_index=True, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((1,), "float32"), 
relax.TensorStructInfo((1,), "int64")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.unique(
+            x3, return_index=False, return_inverse=False, return_counts=False, 
axis=None
+        ),
+        relax.TensorStructInfo((1,), "float32"),
+    )
+
+
+def test_unique_infer_struct_info_axis_out_of_range():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.unique(x0, axis=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.unique(x0, axis=-4))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.unique(x1, axis=0))
+
+
+def test_unique_infer_struct_info_wrong_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.unique(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.unique(x1))
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_set.py b/tests/python/relax/test_tvmscript_parser_op_set.py
new file mode 100644
index 0000000000..8e01fa6f62
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_set.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    test = parsed.script(show_meta=True)
+    roundtrip_mod = tvm.script.from_source(test)
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_unique():
+    """Check that ``R.unique`` in TVMScript round-trips and matches the BlockBuilder form.
+
+    The builder call below omits ``return_index``. The test relies on its default
+    matching the explicit ``return_index=False`` written in the script.
+    """
+
+    # Per-axis unique on axis 1 with inverse and counts: the annotations expect a rank-3
+    # float32 tensor followed by two 1-D int64 tensors.
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 4), dtype="float32")
+    ) -> R.Tuple(
+        R.Tensor(dtype="float32", ndim=3),
+        R.Tensor(dtype="int64", ndim=1),
+        R.Tensor(dtype="int64", ndim=1),
+    ):
+        gv: R.Tuple(
+            R.Tensor(dtype="float32", ndim=3),
+            R.Tensor(dtype="int64", ndim=1),
+            R.Tensor(dtype="int64", ndim=1),
+        ) = R.unique(
+            x, sorted=True, return_index=False, return_inverse=True, return_counts=True, axis=1
+        )
+        return gv
+
+    # Build the same function through the BlockBuilder API for comparison.
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(
+            relax.op.unique(x, sorted=True, return_inverse=True, return_counts=True, axis=1)
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()


Reply via email to