This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new be3d42b8a7 [RELAX][ONNX] Add support for dynamic shape expression in Expand (#17504)
be3d42b8a7 is described below
commit be3d42b8a70e118b39ce1f8a5add0d99df477270
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Tue Nov 5 01:30:07 2024 +0100
[RELAX][ONNX] Add support for dynamic shape expression in Expand (#17504)
* updated expand to support dynamic relax.ShapeExpr
updated slice to convert PrimExpr to PrimValue before sending values to relax.op.strided_slice
* added test for dynamic shape expression in test_expand
* updated formatting
removed unnecessary list comprehension
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 35 +++++++++++++++--
tests/python/relax/test_frontend_onnx.py | 52 ++++++++++++++++++-------
2 files changed, 68 insertions(+), 19 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 611f4348d5..cbd633324a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -909,7 +909,7 @@ class Trilu(OnnxOpConverter):
if len(inputs) > 1:
k = get_constant(inputs[1], params)
if isinstance(k, relax.Constant):
- k = int(k.data.numpy()[0])
+ k = int(k.data.numpy().item())
else:
raise ValueError("Currently only support constant k for Trilu op.")
else:
@@ -1588,6 +1588,16 @@ class Split(OnnxOpConverter):
return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0))
+def get_prim_value_list(values):
+ new_values = []
+ for v in list(values):
+ if isinstance(v, relax.expr.PrimExpr):
+ new_values.append(relax.PrimValue(v))
+ else:
+ new_values.append(v)
+ return new_values
+
+
class Slice(OnnxOpConverter):
"""Converts an onnx Splice node into an equivalent Relax expression."""
@@ -1641,7 +1651,12 @@ class Slice(OnnxOpConverter):
assume_inbound = not all(
[isinstance(param, (tir.IntImm, int)) for param in [*starts, *ends, *steps]]
)
- # return relax.op.strided_slice(data, axes, starts, ends, steps)
+
+ # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr
+ starts = get_prim_value_list(starts)
+ ends = get_prim_value_list(ends)
+ steps = get_prim_value_list(steps)
+
return relax.op.strided_slice(
data, axes, starts, ends, steps, assume_inbound=assume_inbound
)
@@ -1730,9 +1745,21 @@ class Expand(OnnxOpConverter):
def _impl_v13(cls, bb, inputs, attr, params):
data = inputs[0]
shape = inputs[1]
-
if isinstance(shape, relax.ShapeExpr):
- return relax.op.broadcast_to(data, shape)
+ data_shape = list(data.struct_info.shape)
+ target_shape = list(shape.values)
+ data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
+ assert len(data_shape) == len(target_shape)
+ # Fix small target shapes or target shapes assigned to -1
+ for i, s in enumerate(target_shape):
+ if isinstance(s, tvm.tir.IntImm) and (
+ (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
+ or s.value == -1
+ ):
+ target_shape[i] = data_shape[i]
+ if target_shape == data_shape:
+ return data
+ return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
# If possible, directly expand to constant shape.
if isinstance(shape, relax.Constant):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 9faa441138..c130bf4373 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1507,10 +1507,6 @@ def test_topk(axis: int, largest: int):
@pytest.mark.parametrize("dynamic", [False, True])
def test_expand(dynamic):
- if dynamic:
- # TODO: Support dynamic shape for Expand
- pytest.skip("Dynamic expand is not supported yet")
-
def _test_expand(name, data, shape, ref_data):
shape_array = np.array(shape)
shape_node = onnx.helper.make_node(
@@ -1541,17 +1537,43 @@ def test_expand(dynamic):
model = helper.make_model(graph, producer_name=name)
check_correctness(model, inputs={"in": data})
- in_shape = (3, 1)
- shape = (3, 4)
- data = np.random.uniform(size=in_shape).astype(np.float32)
- ref_data = np.tile(data, 4)
- _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
-
- in_shape = (3, 1)
- shape = (1, 3, 4)
- data = np.random.uniform(size=in_shape).astype(np.float32)
- ref_data = np.tile(data, (1, 1, 4))
- _test_expand("expand_with_diff_dim", data, shape, ref_data)
+ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
+ shape_node = onnx.helper.make_node("Shape", inputs=["in_2"], outputs=["shape"])
+ expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+ in_shape = list(data.shape)
+ out_shape = list(ref_data.shape)
+ graph = helper.make_graph(
+ [shape_node, expand_node],
+ "expand_test",
+ inputs=[
+ helper.make_tensor_value_info("in", TensorProto.FLOAT, in_shape),
+ helper.make_tensor_value_info("in_2", TensorProto.FLOAT, shape),
+ ],
+ outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name=name)
+ check_correctness(model, inputs={"in": data, "in_2": shape_data})
+
+ if not dynamic:
+ in_shape = (3, 1)
+ shape = (3, 4)
+ data = np.random.uniform(size=in_shape).astype(np.float32)
+ ref_data = np.tile(data, 4)
+ _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
+
+ in_shape = (3, 1)
+ shape = (1, 3, 4)
+ data = np.random.uniform(size=in_shape).astype(np.float32)
+ ref_data = np.tile(data, (1, 1, 4))
+ _test_expand("expand_with_diff_dim", data, shape, ref_data)
+ else:
+ in_shape = (1, 32, 32)
+ shape = ("batch", 32, 32)
+ data = np.random.uniform(size=in_shape).astype(np.float32)
+ shape_data = np.random.uniform(size=(64, 32, 32)).astype(np.float32)
+ ref_data = np.tile(data, (64, 1, 1))
+ _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.