This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 361d21bbe9 [Relax][ONNX] add support for unique optional outputs (#18652)
361d21bbe9 is described below

commit 361d21bbe9e66bf7fbd8cd630ae49ff3278e176a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 20 10:38:18 2026 +0800

    [Relax][ONNX] add support for unique optional outputs (#18652)
    
    ## Why
    
    The ONNX Unique operator supports four optional outputs (unique values,
    indices, inverse_indices, and counts), but the TVM ONNX frontend only
    returned the unique values output.
    
    ## How
    
    - Updated `Unique._impl_v11` to check the number of expected outputs via
    `attr["tvm_custom"]["num_outputs"]`
    - Pass `return_index`, `return_inverse`, and `return_counts` parameters
    to `relax.op.unique`
    - Return a `relax.Tuple` containing all requested outputs
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 65 +++++++++++++++----
 python/tvm/relax/op/set.py                      | 85 ++++++++++++++++++++++---
 src/relax/op/tensor/set.cc                      | 43 ++++++++++---
 tests/python/relax/test_frontend_onnx.py        | 23 +++++--
 4 files changed, 180 insertions(+), 36 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4dbb0ca36f..e14e2ed956 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3239,24 +3239,63 @@ class Unique(OnnxOpConverter):
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = attr.get("axis", None)
-        sorted = bool(attr.get("sorted", 1))
-        # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
-        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        sorted_flag = bool(attr.get("sorted", 1))
+        num_outputs = attr["tvm_custom"]["num_outputs"]
+
+        return_index = num_outputs > 1
+        return_inverse = num_outputs > 2
+        return_counts = num_outputs > 3
+
+        unique = relax.op.unique(
+            data,
+            sorted=sorted_flag,
+            return_index=return_index,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            axis=axis,
+        )
+
         unique_numbers = tir.Var("unique_numbers", "int64")
         input_shape = data.struct_info.shape
         dtype = data.struct_info.dtype
 
         if axis is None:
-            # flatten the input tensor
-            return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))
-
-        axis = axis if axis >= 0 else len(input_shape) + axis
-        if axis < 0 or axis >= len(input_shape):
-            raise ValueError(f"Axis {axis} is out of bounds")
-        output_shape = [
-            input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
-        ]
-        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
+            output_shape = (unique_numbers,)
+        else:
+            axis = axis if axis >= 0 else len(input_shape) + axis
+            if axis < 0 or axis >= len(input_shape):
+                raise ValueError(f"Axis {axis} is out of bounds")
+            output_shape = [
+                input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
+            ]
+
+        if num_outputs == 1:
+            return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
+
+        outputs = [bb.match_cast(unique[0], relax.TensorStructInfo(output_shape, dtype))]
+        tuple_idx = 1  # Track which index in the tuple we're at
+
+        if return_index:
+            index_shape = (unique_numbers,)
+            index_sinfo = relax.TensorStructInfo(index_shape, "int64")
+            outputs.append(bb.match_cast(unique[tuple_idx], index_sinfo))
+            tuple_idx += 1
+
+        if return_inverse:
+            # ONNX spec: inverse_indices is always 1D
+            # When axis is None: shape is [X.size]
+            # When axis is specified: shape is [X.shape[axis]]
+            inverse_shape = (tir.Var("inverse_numbers", "int64"),)
+            inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64")
+            outputs.append(bb.match_cast(unique[tuple_idx], inverse_sinfo))
+            tuple_idx += 1
+
+        if return_counts:
+            count_shape = (unique_numbers,)
+            count_sinfo = relax.TensorStructInfo(count_shape, "int64")
+            outputs.append(bb.match_cast(unique[tuple_idx], count_sinfo))
+
+        return relax.Tuple(outputs)
 
 
 class NonZero(OnnxOpConverter):
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index 87fd067e5d..a7d837d673 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -99,17 +99,84 @@ def numpy_unique(
     """
     import builtins
 
-    # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True
-    if bool(return_index) or bool(return_inverse) or bool(return_counts):
-        raise NotImplementedError("missing support return_inverse or return_counts set to true")
     x_numpy = x.numpy()
-    # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
-    output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis)
 
-    if sorted:
-        return tvm.runtime.tensor(output_sorted_numpy)
-    output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
-    return tvm.runtime.tensor(output_numpy)
+    # Call numpy.unique with all the requested return flags
+    result = np.unique(
+        x_numpy,
+        return_index=bool(return_index),
+        return_inverse=bool(return_inverse),
+        return_counts=bool(return_counts),
+        axis=axis,
+    )
+
+    # If no optional outputs requested, result is just the unique values
+    if not bool(return_index) and not bool(return_inverse) and not bool(return_counts):
+        unique_values = result
+        if not sorted:
+            indices = np.unique(x_numpy, return_index=True, axis=axis)[1]
+            unique_values = np.take(x_numpy, builtins.sorted(indices), axis=axis)
+        return tvm.runtime.tensor(unique_values)
+
+    # Otherwise, numpy returns a tuple
+    unique_values = result[0]
+    output_list = []
+    result_idx = 1
+
+    # Handle sorting for unique values
+    if not sorted and bool(return_index):
+        # Get the indices from numpy result
+        indices = result[result_idx]
+        result_idx += 1
+        # Sort indices to get original order
+        sort_order = np.argsort(indices)
+        unique_values = np.take(unique_values, sort_order, axis=axis)
+        indices = np.sort(indices)
+        output_list.append(tvm.runtime.tensor(unique_values))
+        output_list.append(tvm.runtime.tensor(indices))
+    elif not sorted:
+        # Need to get indices to reorder
+        _, indices = np.unique(x_numpy, return_index=True, axis=axis)
+        sort_order = np.argsort(indices)
+        unique_values = np.take(unique_values, sort_order, axis=axis)
+        output_list.append(tvm.runtime.tensor(unique_values))
+        if bool(return_index):
+            indices_from_result = result[result_idx]
+            result_idx += 1
+            output_list.append(tvm.runtime.tensor(np.sort(indices_from_result)))
+    else:
+        # Sorted case
+        output_list.append(tvm.runtime.tensor(unique_values))
+        if bool(return_index):
+            output_list.append(tvm.runtime.tensor(result[result_idx]))
+            result_idx += 1
+
+    if bool(return_inverse):
+        inverse_indices = result[result_idx]
+        if not sorted:
+            # Need to remap inverse indices to match reordered unique values
+            _, orig_indices = np.unique(x_numpy, return_index=True, axis=axis)
+            sort_order = np.argsort(orig_indices)
+            inverse_mapping = np.empty_like(sort_order)
+            inverse_mapping[sort_order] = np.arange(len(sort_order))
+            inverse_indices = inverse_mapping[inverse_indices]
+        # ONNX spec: inverse_indices is always 1D
+        # When axis is None, it has length X.size (flattened)
+        # When axis is specified, it has length X.shape[axis]
+        # numpy.unique already returns 1D inverse_indices, so no reshaping needed
+        output_list.append(tvm.runtime.tensor(inverse_indices))
+        result_idx += 1
+
+    if bool(return_counts):
+        counts = result[result_idx]
+        if not sorted:
+            # Reorder counts to match reordered unique values
+            _, orig_indices = np.unique(x_numpy, return_index=True, axis=axis)
+            sort_order = np.argsort(orig_indices)
+            counts = counts[sort_order]
+        output_list.append(tvm.runtime.tensor(counts))
+
+    return tuple(output_list)
 
 
 def nonzero(x: Expr) -> Expr:
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index d80c73b131..c3ee496794 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -101,16 +101,41 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
     output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice));
   }
 
-  // index, reverse and counts
-  TensorStructInfo int_return{nullptr};
-  if (data_sinfo->ndim == 0) {
-    int_return = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
-                                  DataType::Int(64), data_sinfo->vdevice);
-  } else {
-    int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+  // index, inverse_indices, and counts
+  // index: always 1D
+  if (f_convert_to_int64(return_index->value)) {
+    TensorStructInfo index_sinfo{nullptr};
+    if (data_sinfo->ndim == 0) {
+      index_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+                                     DataType::Int(64), data_sinfo->vdevice);
+    } else {
+      index_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+    }
+    output_sinfo.push_back(index_sinfo);
+  }
+
+  // inverse_indices: always 1D per ONNX spec
+  if (f_convert_to_int64(return_inverse->value)) {
+    TensorStructInfo inverse_sinfo{nullptr};
+    if (data_sinfo->ndim == 0) {
+      inverse_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+                                       DataType::Int(64), data_sinfo->vdevice);
+    } else {
+      inverse_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+    }
+    output_sinfo.push_back(inverse_sinfo);
   }
-  for (int i = 0; i < n_int_return; ++i) {
-    output_sinfo.push_back(int_return);
+
+  // counts: always 1D
+  if (f_convert_to_int64(return_counts->value)) {
+    TensorStructInfo counts_sinfo{nullptr};
+    if (data_sinfo->ndim == 0) {
+      counts_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+                                      DataType::Int(64), data_sinfo->vdevice);
+    } else {
+      counts_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+    }
+    output_sinfo.push_back(counts_sinfo);
   }
 
   if (output_sinfo.size() == 1) {
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6f5c7da5ef..df94c13478 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2922,19 +2922,32 @@ def test_onehot():
 
 @pytest.mark.parametrize("axis", [None, 0, 1, -1])
 @pytest.mark.parametrize("sorted", [0, 1])
-def test_unique(axis: Optional[int], sorted: int):
-    input_shape = [32, 32]
+@pytest.mark.parametrize("num_outputs", [1, 2, 3, 4])
+def test_unique(axis: Optional[int], sorted: int, num_outputs: int):
+    input_shape = [8, 8]
     if axis is None:
         output_shape = [-1]
     else:
-        output_shape = [32, 32]
+        output_shape = [8, 8]
         output_shape[axis] = -1
-    unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, sorted=sorted)
+
+    output_names = ["y", "indices", "inverse_indices", "counts"][:num_outputs]
+    unique_node = helper.make_node("Unique", ["x"], output_names, axis=axis, sorted=sorted)
+
+    outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)]
+    if num_outputs > 1:
+        outputs.append(helper.make_tensor_value_info("indices", TensorProto.INT64, [-1]))
+    if num_outputs > 2:
+        # ONNX spec: inverse_indices is always 1D
+        outputs.append(helper.make_tensor_value_info("inverse_indices", TensorProto.INT64, [-1]))
+    if num_outputs > 3:
+        outputs.append(helper.make_tensor_value_info("counts", TensorProto.INT64, [-1]))
+
     graph = helper.make_graph(
         [unique_node],
         "unique_test",
         inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+        outputs=outputs,
     )
     model = helper.make_model(graph, producer_name="unique_test")
     check_correctness(model)

Reply via email to