gemini-code-assist[bot] commented on code in PR #18652:
URL: https://github.com/apache/tvm/pull/18652#discussion_r2678448140


##########
tests/python/relax/test_frontend_onnx.py:
##########
@@ -2884,19 +2884,32 @@ def test_onehot():
 
 @pytest.mark.parametrize("axis", [None, 0, 1, -1])
 @pytest.mark.parametrize("sorted", [0, 1])
-def test_unique(axis: Optional[int], sorted: int):
-    input_shape = [32, 32]
+@pytest.mark.parametrize("num_outputs", [1, 2, 3, 4])
+def test_unique(axis: Optional[int], sorted: int, num_outputs: int):
+    input_shape = [8, 8]
     if axis is None:
         output_shape = [-1]
     else:
-        output_shape = [32, 32]
+        output_shape = [8, 8]
         output_shape[axis] = -1
-    unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, 
sorted=sorted)
+
+    output_names = ["y", "indices", "inverse_indices", "counts"][:num_outputs]
+    unique_node = helper.make_node("Unique", ["x"], output_names, axis=axis, 
sorted=sorted)
+
+    outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, 
output_shape)]
+    if num_outputs > 1:
+        outputs.append(helper.make_tensor_value_info("indices", 
TensorProto.INT64, output_shape))
+    if num_outputs > 2:
+        inverse_shape = [-1] if axis is None else input_shape
+        outputs.append(helper.make_tensor_value_info("inverse_indices", 
TensorProto.INT64, inverse_shape))
+    if num_outputs > 3:
+        outputs.append(helper.make_tensor_value_info("counts", 
TensorProto.INT64, output_shape))

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The ONNX `Unique` operator's `counts` output is always a 1-D tensor. The 
shape should be `[-1]` to represent a 1-D tensor of dynamic length, not 
`output_shape`, which can be multi-dimensional when `axis` is specified.
   
   ```suggestion
           outputs.append(helper.make_tensor_value_info("counts", 
TensorProto.INT64, [-1]))
   ```



##########
tests/python/relax/test_frontend_onnx.py:
##########
@@ -2884,19 +2884,32 @@ def test_onehot():
 
 @pytest.mark.parametrize("axis", [None, 0, 1, -1])
 @pytest.mark.parametrize("sorted", [0, 1])
-def test_unique(axis: Optional[int], sorted: int):
-    input_shape = [32, 32]
+@pytest.mark.parametrize("num_outputs", [1, 2, 3, 4])
+def test_unique(axis: Optional[int], sorted: int, num_outputs: int):
+    input_shape = [8, 8]
     if axis is None:
         output_shape = [-1]
     else:
-        output_shape = [32, 32]
+        output_shape = [8, 8]
         output_shape[axis] = -1
-    unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, 
sorted=sorted)
+
+    output_names = ["y", "indices", "inverse_indices", "counts"][:num_outputs]
+    unique_node = helper.make_node("Unique", ["x"], output_names, axis=axis, 
sorted=sorted)
+
+    outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, 
output_shape)]
+    if num_outputs > 1:
+        outputs.append(helper.make_tensor_value_info("indices", 
TensorProto.INT64, output_shape))

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The ONNX `Unique` operator's `indices` output is always a 1-D tensor. The 
shape should be `[-1]` to represent a 1-D tensor of dynamic length, not 
`output_shape`, which can be multi-dimensional when `axis` is specified.
   
   ```suggestion
           outputs.append(helper.make_tensor_value_info("indices", 
TensorProto.INT64, [-1]))
   ```



##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3182,24 +3182,62 @@ class Unique(OnnxOpConverter):
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = attr.get("axis", None)
-        sorted = bool(attr.get("sorted", 1))
-        # TODO(tvm-team): Add support for return_index, return_inverse, 
return_counts
-        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        sorted_flag = bool(attr.get("sorted", 1))
+        num_outputs = attr["tvm_custom"]["num_outputs"]
+
+        return_index = num_outputs > 1
+        return_inverse = num_outputs > 2
+        return_counts = num_outputs > 3
+
+        unique = relax.op.unique(
+            data,
+            sorted=sorted_flag,
+            return_index=return_index,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            axis=axis
+        )
+
         unique_numbers = tir.Var("unique_numbers", "int64")
         input_shape = data.struct_info.shape
         dtype = data.struct_info.dtype
 
         if axis is None:
-            # flatten the input tensor
-            return bb.match_cast(unique, 
relax.TensorStructInfo((unique_numbers,), dtype))
-
-        axis = axis if axis >= 0 else len(input_shape) + axis
-        if axis < 0 or axis >= len(input_shape):
-            raise ValueError(f"Axis {axis} is out of bounds")
-        output_shape = [
-            input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
-        ]
-        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+            output_shape = (unique_numbers,)
+        else:
+            axis = axis if axis >= 0 else len(input_shape) + axis
+            if axis < 0 or axis >= len(input_shape):
+                raise ValueError(f"Axis {axis} is out of bounds")
+            output_shape = [
+                input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
+            ]
+
+        if num_outputs == 1:
+            return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+
+        outputs = [bb.match_cast(unique[0], 
relax.TensorStructInfo(output_shape, dtype))]
+
+        if return_index:
+            index_shape = (unique_numbers,) if axis is None else output_shape
+            index_sinfo = relax.TensorStructInfo(index_shape, "int64")
+            outputs.append(bb.match_cast(unique[1], index_sinfo))
+
+        if return_inverse:
+            if axis is None:
+                inverse_shape = (tir.Var("inverse_numbers", "int64"),)
+            else:
+                inverse_shape = input_shape
+            idx = 2 if return_index else 1
+            inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64")
+            outputs.append(bb.match_cast(unique[idx], inverse_sinfo))
+
+        if return_counts:
+            count_shape = (unique_numbers,) if axis is None else output_shape

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   According to the ONNX `Unique` operator specification, the `counts` output 
is always a 1-D tensor of shape `(num_unique_values,)`, regardless of whether 
`axis` is specified. The current implementation incorrectly uses `output_shape` 
when `axis` is not None, even though that shape can be multi-dimensional. This 
should be `(unique_numbers,)` in all cases.
   
   ```suggestion
               count_shape = (unique_numbers,)
   ```



##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3182,24 +3182,62 @@ class Unique(OnnxOpConverter):
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = attr.get("axis", None)
-        sorted = bool(attr.get("sorted", 1))
-        # TODO(tvm-team): Add support for return_index, return_inverse, 
return_counts
-        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        sorted_flag = bool(attr.get("sorted", 1))
+        num_outputs = attr["tvm_custom"]["num_outputs"]
+
+        return_index = num_outputs > 1
+        return_inverse = num_outputs > 2
+        return_counts = num_outputs > 3
+
+        unique = relax.op.unique(
+            data,
+            sorted=sorted_flag,
+            return_index=return_index,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            axis=axis
+        )
+
         unique_numbers = tir.Var("unique_numbers", "int64")
         input_shape = data.struct_info.shape
         dtype = data.struct_info.dtype
 
         if axis is None:
-            # flatten the input tensor
-            return bb.match_cast(unique, 
relax.TensorStructInfo((unique_numbers,), dtype))
-
-        axis = axis if axis >= 0 else len(input_shape) + axis
-        if axis < 0 or axis >= len(input_shape):
-            raise ValueError(f"Axis {axis} is out of bounds")
-        output_shape = [
-            input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
-        ]
-        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+            output_shape = (unique_numbers,)
+        else:
+            axis = axis if axis >= 0 else len(input_shape) + axis
+            if axis < 0 or axis >= len(input_shape):
+                raise ValueError(f"Axis {axis} is out of bounds")
+            output_shape = [
+                input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
+            ]
+
+        if num_outputs == 1:
+            return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+
+        outputs = [bb.match_cast(unique[0], 
relax.TensorStructInfo(output_shape, dtype))]
+
+        if return_index:
+            index_shape = (unique_numbers,) if axis is None else output_shape
+            index_sinfo = relax.TensorStructInfo(index_shape, "int64")
+            outputs.append(bb.match_cast(unique[1], index_sinfo))
+
+        if return_inverse:
+            if axis is None:
+                inverse_shape = (tir.Var("inverse_numbers", "int64"),)
+            else:
+                inverse_shape = input_shape
+            idx = 2 if return_index else 1
+            inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64")
+            outputs.append(bb.match_cast(unique[idx], inverse_sinfo))
+
+        if return_counts:
+            count_shape = (unique_numbers,) if axis is None else output_shape
+            idx = (1 if return_index else 0) + (1 if return_inverse else 0) + 1

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Since `return_counts` is true only when `num_outputs > 3`, it implies both 
`return_index` and `return_inverse` are also true. Therefore, `idx` will always 
be 3 in this context. You can simplify this for clarity.
   
   ```suggestion
               idx = 3
   ```



##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3182,24 +3182,62 @@ class Unique(OnnxOpConverter):
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = attr.get("axis", None)
-        sorted = bool(attr.get("sorted", 1))
-        # TODO(tvm-team): Add support for return_index, return_inverse, 
return_counts
-        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        sorted_flag = bool(attr.get("sorted", 1))
+        num_outputs = attr["tvm_custom"]["num_outputs"]
+
+        return_index = num_outputs > 1
+        return_inverse = num_outputs > 2
+        return_counts = num_outputs > 3
+
+        unique = relax.op.unique(
+            data,
+            sorted=sorted_flag,
+            return_index=return_index,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            axis=axis
+        )
+
         unique_numbers = tir.Var("unique_numbers", "int64")
         input_shape = data.struct_info.shape
         dtype = data.struct_info.dtype
 
         if axis is None:
-            # flatten the input tensor
-            return bb.match_cast(unique, 
relax.TensorStructInfo((unique_numbers,), dtype))
-
-        axis = axis if axis >= 0 else len(input_shape) + axis
-        if axis < 0 or axis >= len(input_shape):
-            raise ValueError(f"Axis {axis} is out of bounds")
-        output_shape = [
-            input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
-        ]
-        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+            output_shape = (unique_numbers,)
+        else:
+            axis = axis if axis >= 0 else len(input_shape) + axis
+            if axis < 0 or axis >= len(input_shape):
+                raise ValueError(f"Axis {axis} is out of bounds")
+            output_shape = [
+                input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
+            ]
+
+        if num_outputs == 1:
+            return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+
+        outputs = [bb.match_cast(unique[0], 
relax.TensorStructInfo(output_shape, dtype))]
+
+        if return_index:
+            index_shape = (unique_numbers,) if axis is None else output_shape
+            index_sinfo = relax.TensorStructInfo(index_shape, "int64")
+            outputs.append(bb.match_cast(unique[1], index_sinfo))
+
+        if return_inverse:
+            if axis is None:
+                inverse_shape = (tir.Var("inverse_numbers", "int64"),)
+            else:
+                inverse_shape = input_shape
+            idx = 2 if return_index else 1

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Since `return_inverse` is true only when `num_outputs > 2`, it implies 
`return_index` (`num_outputs > 1`) is also true. Therefore, `idx` will always 
be 2 in this context. You can simplify this for clarity.
   
   ```suggestion
               idx = 2
   ```



##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3182,24 +3182,62 @@ class Unique(OnnxOpConverter):
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
         axis = attr.get("axis", None)
-        sorted = bool(attr.get("sorted", 1))
-        # TODO(tvm-team): Add support for return_index, return_inverse, 
return_counts
-        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        sorted_flag = bool(attr.get("sorted", 1))
+        num_outputs = attr["tvm_custom"]["num_outputs"]
+
+        return_index = num_outputs > 1
+        return_inverse = num_outputs > 2
+        return_counts = num_outputs > 3
+
+        unique = relax.op.unique(
+            data,
+            sorted=sorted_flag,
+            return_index=return_index,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            axis=axis
+        )
+
         unique_numbers = tir.Var("unique_numbers", "int64")
         input_shape = data.struct_info.shape
         dtype = data.struct_info.dtype
 
         if axis is None:
-            # flatten the input tensor
-            return bb.match_cast(unique, 
relax.TensorStructInfo((unique_numbers,), dtype))
-
-        axis = axis if axis >= 0 else len(input_shape) + axis
-        if axis < 0 or axis >= len(input_shape):
-            raise ValueError(f"Axis {axis} is out of bounds")
-        output_shape = [
-            input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
-        ]
-        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+            output_shape = (unique_numbers,)
+        else:
+            axis = axis if axis >= 0 else len(input_shape) + axis
+            if axis < 0 or axis >= len(input_shape):
+                raise ValueError(f"Axis {axis} is out of bounds")
+            output_shape = [
+                input_shape[i] if i != axis else unique_numbers for i in 
range(len(input_shape))
+            ]
+
+        if num_outputs == 1:
+            return bb.match_cast(unique, relax.TensorStructInfo(output_shape, 
dtype))
+
+        outputs = [bb.match_cast(unique[0], 
relax.TensorStructInfo(output_shape, dtype))]
+
+        if return_index:
+            index_shape = (unique_numbers,) if axis is None else output_shape

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   According to the ONNX `Unique` operator specification, the `indices` output 
is always a 1-D tensor of shape `(num_unique_values,)`, regardless of whether 
`axis` is specified. The current implementation incorrectly uses `output_shape` 
when `axis` is not None, even though that shape can be multi-dimensional. This 
should be `(unique_numbers,)` in all cases.
   
   ```suggestion
               index_shape = (unique_numbers,)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to