This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c7e9292c9e [Relax] Update ONNX frontend for unique, nonzero and 
compress (#17511)
c7e9292c9e is described below

commit c7e9292c9eabbe5bcfe79466ed12a75a5c3e2f4f
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Nov 12 23:55:02 2024 +0800

    [Relax] Update ONNX frontend for unique, nonzero and compress (#17511)
    
    This PR updates the ONNX frontend:
    
    - Add match cast for unique and nonzero operators, enabling further import 
of ONNX models.
    - Add support for compress operator.
    - Fix the shape of the output tensor for nonzero operator.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 52 +++++++++++++++++++++++--
 python/tvm/relax/op/set.py                      |  2 +-
 src/relax/op/tensor/set.cc                      |  4 +-
 tests/python/relax/test_frontend_onnx.py        | 30 +++++++++++++-
 tests/python/relax/test_op_set.py               |  2 +-
 5 files changed, 81 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index eb7a3eaf36..94ccfdb23e 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -833,6 +833,32 @@ class ScatterND(OnnxOpConverter):
         return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
 
 
+class Compress(OnnxOpConverter):
+    """Convert an onnx Compress node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        tensor, condition = inputs
+        axis = attr.get("axis", None)
+
+        # Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
+        if condition.struct_info.dtype != "bool":
+            raise ValueError("Condition tensor is expected to be a boolean tensor")
+        if condition.struct_info.ndim != 1:
+            raise ValueError("Condition tensor is expected to be a 1D boolean tensor")
+        indices = relax.op.nonzero(condition)
+        num_nonzero = tir.Var("num_nonzero", "int64")
+        indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64"))
+        indices = relax.op.reshape(indices, [-1])
+
+        if axis is not None:
+            return relax.op.take(tensor, indices, axis=axis)
+
+        # if axis is None, flatten input tensor before selection
+        tensor = relax.op.reshape(tensor, (-1,))
+        return relax.op.take(tensor, indices, axis=0)
+
+
 class Size(OnnxOpConverter):
     """Convert an onnx Size node into an equivalent Relax expression."""
 
@@ -2726,7 +2752,22 @@ class Unique(OnnxOpConverter):
         axis = attr.get("axis", None)
         sorted = bool(attr.get("sorted", 1))
         # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
-        return relax.op.unique(data, sorted=sorted, axis=axis)
+        unique = relax.op.unique(data, sorted=sorted, axis=axis)
+        unique_numbers = tir.Var("unique_numbers", "int64")
+        input_shape = data.struct_info.shape
+        dtype = data.struct_info.dtype
+
+        if axis is None:
+            # flatten the input tensor
+            return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))
+
+        axis = axis if axis >= 0 else len(input_shape) + axis
+        if axis < 0 or axis >= len(input_shape):
+            raise ValueError(f"Axis {axis} is out of bounds")
+        output_shape = [
+            input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
+        ]
+        return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
 
 
 class NonZero(OnnxOpConverter):
@@ -2734,7 +2775,12 @@ class NonZero(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, bb, inputs, attr, params):
-        return relax.op.nonzero(inputs[0])
+        ndim = inputs[0].struct_info.ndim
+        ndim = 1 if ndim == 0 else ndim
+        nonzero_numbers = tir.Var("nonzero_numbers", "int64")
+        return bb.match_cast(
+            relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64")
+        )
 
 
 class HardSigmoid(OnnxOpConverter):
@@ -3075,7 +3121,7 @@ def _get_convert_map():
         "Scatter": Scatter,
         "ScatterElements": ScatterElements,
         "ScatterND": ScatterND,
-        # "Compress": Compress,
+        "Compress": Compress,
         "Size": Size,
         "EyeLike": EyeLike,
         # Normalization
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index c5db852ddd..ed4b2e2ff9 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -123,7 +123,7 @@ def nonzero(x: Expr) -> Expr:
     Returns
     -------
     result : relax.Expr
-        A (n+1)-D tensor containing indices of non-zero elements.
+        A 2-D tensor containing indices of non-zero elements.
 
     Note
     ----
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index c659a49afd..e2aef8005e 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -148,9 +148,7 @@ TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);
 
 StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
-  // Cheat zero dim scalar as 1-dim.
-  int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
-  return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
+  return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice);
 }
 
 TVM_REGISTER_OP("relax.nonzero")
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6f74957a07..a4a4f78bd3 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -601,6 +601,34 @@ def test_scatter_nd(reduction):
     verify_scatter_nd([10], [5, 1], [5])
 
 
+@pytest.mark.parametrize("tensor_shape", [[32, 32]])
+@pytest.mark.parametrize("condition_shape", [None, [8], [16]])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_compress(
+    tensor_shape: List[int],
+    condition_shape: Optional[List[int]],
+    axis: Optional[int],
+):
+    if condition_shape is None and axis is None:
+        pytest.skip("Either condition_shape or axis must be specified")
+    if condition_shape is None:
+        condition_shape = [tensor_shape[axis]]
+    compress_node = helper.make_node("Compress", ["tensor", "condition"], ["output"], axis=axis)
+    graph = helper.make_graph(
+        [compress_node],
+        "compress_test",
+        inputs=[
+            helper.make_tensor_value_info("tensor", TensorProto.FLOAT, tensor_shape),
+            helper.make_tensor_value_info("condition", TensorProto.BOOL, condition_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [])
+        ],  # shape is unknown
+    )
+    model = helper.make_model(graph, producer_name="compress_test")
+    check_correctness(model, opset=11)
+
+
 def test_size():
     test_node = helper.make_node("Size", ["x"], ["y"])
     graph = helper.make_graph(
@@ -2478,7 +2506,7 @@ def test_unique(axis: Optional[int], sorted: int):
     check_correctness(model)
 
 
-@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
+@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)])
 def test_nonzero(shape):
     verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)
 
diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py
index e9070f99fc..05b6d8887b 100644
--- a/tests/python/relax/test_op_set.py
+++ b/tests/python/relax/test_op_set.py
@@ -875,7 +875,7 @@ def test_nonzero_infer_struct_info(shape):
     _check_inference(
         bb,
         relax.op.nonzero(x0),
-        relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
+        relax.TensorStructInfo(ndim=2, dtype="int64"),
     )
 
 

Reply via email to