This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c7e9292c9e [Relax] Update ONNX frontend for unique, nonzero and compress (#17511)
c7e9292c9e is described below
commit c7e9292c9eabbe5bcfe79466ed12a75a5c3e2f4f
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Nov 12 23:55:02 2024 +0800
[Relax] Update ONNX frontend for unique, nonzero and compress (#17511)
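This PR updates the ONNX frontend and the Relax nonzero operator:
- Add a match cast for the Unique and NonZero operators so that their
  data-dependent output shapes are bound to fresh symbolic variables,
  enabling import of ONNX models that consume these outputs.
- Add support for the Compress operator.
- Fix the output shape of the nonzero operator: its result is now always a
  2-D tensor of shape (ndim, num_nonzero), with 0-D inputs treated as 1-D.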
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 52 +++++++++++++++++++++++--
python/tvm/relax/op/set.py | 2 +-
src/relax/op/tensor/set.cc | 4 +-
tests/python/relax/test_frontend_onnx.py | 30 +++++++++++++-
tests/python/relax/test_op_set.py | 2 +-
5 files changed, 81 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index eb7a3eaf36..94ccfdb23e 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -833,6 +833,32 @@ class ScatterND(OnnxOpConverter):
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
+class Compress(OnnxOpConverter):
+ """Convert an onnx Compress node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v11(cls, bb, inputs, attr, params):
+ tensor, condition = inputs
+ axis = attr.get("axis", None)
+
+ # Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
+ if condition.struct_info.dtype != "bool":
+ raise ValueError("Condition tensor is expected to be a boolean
tensor")
+ if condition.struct_info.ndim != 1:
+ raise ValueError("Condition tensor is expected to be a 1D boolean
tensor")
+ indices = relax.op.nonzero(condition)
+ num_nonzero = tir.Var("num_nonzero", "int64")
+ indices = bb.match_cast(indices, relax.TensorStructInfo([1,
num_nonzero], "int64"))
+ indices = relax.op.reshape(indices, [-1])
+
+ if axis is not None:
+ return relax.op.take(tensor, indices, axis=axis)
+
+ # if axis is None, flatten input tensor before selection
+ tensor = relax.op.reshape(tensor, (-1,))
+ return relax.op.take(tensor, indices, axis=0)
+
+
class Size(OnnxOpConverter):
"""Convert an onnx Size node into an equivalent Relax expression."""
@@ -2726,7 +2752,22 @@ class Unique(OnnxOpConverter):
axis = attr.get("axis", None)
sorted = bool(attr.get("sorted", 1))
# TODO(tvm-team): Add support for return_index, return_inverse,
return_counts
- return relax.op.unique(data, sorted=sorted, axis=axis)
+ unique = relax.op.unique(data, sorted=sorted, axis=axis)
+ unique_numbers = tir.Var("unique_numbers", "int64")
+ input_shape = data.struct_info.shape
+ dtype = data.struct_info.dtype
+
+ if axis is None:
+ # flatten the input tensor
+ return bb.match_cast(unique,
relax.TensorStructInfo((unique_numbers,), dtype))
+
+ axis = axis if axis >= 0 else len(input_shape) + axis
+ if axis < 0 or axis >= len(input_shape):
+ raise ValueError(f"Axis {axis} is out of bounds")
+ output_shape = [
+ input_shape[i] if i != axis else unique_numbers for i in
range(len(input_shape))
+ ]
+ return bb.match_cast(unique, relax.TensorStructInfo(output_shape,
dtype))
class NonZero(OnnxOpConverter):
@@ -2734,7 +2775,12 @@ class NonZero(OnnxOpConverter):
@classmethod
def _impl_v9(cls, bb, inputs, attr, params):
- return relax.op.nonzero(inputs[0])
+ ndim = inputs[0].struct_info.ndim
+ ndim = 1 if ndim == 0 else ndim
+ nonzero_numbers = tir.Var("nonzero_numbers", "int64")
+ return bb.match_cast(
+ relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim,
nonzero_numbers), "int64")
+ )
class HardSigmoid(OnnxOpConverter):
@@ -3075,7 +3121,7 @@ def _get_convert_map():
"Scatter": Scatter,
"ScatterElements": ScatterElements,
"ScatterND": ScatterND,
- # "Compress": Compress,
+ "Compress": Compress,
"Size": Size,
"EyeLike": EyeLike,
# Normalization
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index c5db852ddd..ed4b2e2ff9 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -123,7 +123,7 @@ def nonzero(x: Expr) -> Expr:
Returns
-------
result : relax.Expr
- A (n+1)-D tensor containing indices of non-zero elements.
+ A 2-D tensor containing indices of non-zero elements.
Note
----
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index c659a49afd..e2aef8005e 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -148,9 +148,7 @@ TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);
StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
- // Cheat zero dim scalar as 1-dim.
- int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1,
data_sinfo->ndim) + 1;
- return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
+ return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice);
}
TVM_REGISTER_OP("relax.nonzero")
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6f74957a07..a4a4f78bd3 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -601,6 +601,34 @@ def test_scatter_nd(reduction):
verify_scatter_nd([10], [5, 1], [5])
+@pytest.mark.parametrize("tensor_shape", [[32, 32]])
+@pytest.mark.parametrize("condition_shape", [None, [8], [16]])
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_compress(
+ tensor_shape: List[int],
+ condition_shape: Optional[List[int]],
+ axis: Optional[int],
+):
+ if condition_shape is None and axis is None:
+ pytest.skip("Either condition_shape or axis must be specified")
+ if condition_shape is None:
+ condition_shape = [tensor_shape[axis]]
+ compress_node = helper.make_node("Compress", ["tensor", "condition"],
["output"], axis=axis)
+ graph = helper.make_graph(
+ [compress_node],
+ "compress_test",
+ inputs=[
+ helper.make_tensor_value_info("tensor", TensorProto.FLOAT,
tensor_shape),
+ helper.make_tensor_value_info("condition", TensorProto.BOOL,
condition_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("output", TensorProto.FLOAT, [])
+ ], # shape is unknown
+ )
+ model = helper.make_model(graph, producer_name="compress_test")
+ check_correctness(model, opset=11)
+
+
def test_size():
test_node = helper.make_node("Size", ["x"], ["y"])
graph = helper.make_graph(
@@ -2478,7 +2506,7 @@ def test_unique(axis: Optional[int], sorted: int):
check_correctness(model)
-@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
+@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)])
def test_nonzero(shape):
verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL,
output_dtype=TensorProto.INT64)
diff --git a/tests/python/relax/test_op_set.py
b/tests/python/relax/test_op_set.py
index e9070f99fc..05b6d8887b 100644
--- a/tests/python/relax/test_op_set.py
+++ b/tests/python/relax/test_op_set.py
@@ -875,7 +875,7 @@ def test_nonzero_infer_struct_info(shape):
_check_inference(
bb,
relax.op.nonzero(x0),
- relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
+ relax.TensorStructInfo(ndim=2, dtype="int64"),
)