This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new be3d42b8a7 [RELAX][ONNX] Add support for dynamic shape expression in Expand (#17504)
be3d42b8a7 is described below

commit be3d42b8a70e118b39ce1f8a5add0d99df477270
Author: PatrikPerssonInceptron <[email protected]>
AuthorDate: Tue Nov 5 01:30:07 2024 +0100

    [RELAX][ONNX] Add support for dynamic shape expression in Expand (#17504)
    
    * updated expand to support dynamic relax.ShapeExpr
    
    updated slice to convert PrimExpr to PrimValue before sending values to
    relax.op.strided_slice
    
    * added test for dynamic shape expression in test_expand
    
    * updated formatting
    
    removed unnecessary list comprehension
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 35 +++++++++++++++--
 tests/python/relax/test_frontend_onnx.py        | 52 ++++++++++++++++++-------
 2 files changed, 68 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 611f4348d5..cbd633324a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -909,7 +909,7 @@ class Trilu(OnnxOpConverter):
         if len(inputs) > 1:
             k = get_constant(inputs[1], params)
             if isinstance(k, relax.Constant):
-                k = int(k.data.numpy()[0])
+                k = int(k.data.numpy().item())
             else:
                 raise ValueError("Currently only support constant k for Trilu op.")
         else:
@@ -1588,6 +1588,16 @@ class Split(OnnxOpConverter):
         return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0))
 
 
+def get_prim_value_list(values):
+    new_values = []
+    for v in list(values):
+        if isinstance(v, relax.expr.PrimExpr):
+            new_values.append(relax.PrimValue(v))
+        else:
+            new_values.append(v)
+    return new_values
+
+
 class Slice(OnnxOpConverter):
     """Converts an onnx Splice node into an equivalent Relax expression."""
 
@@ -1641,7 +1651,12 @@ class Slice(OnnxOpConverter):
         assume_inbound = not all(
             [isinstance(param, (tir.IntImm, int)) for param in [*starts, *ends, *steps]]
         )
-        # return relax.op.strided_slice(data, axes, starts, ends, steps)
+
+        # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr
+        starts = get_prim_value_list(starts)
+        ends = get_prim_value_list(ends)
+        steps = get_prim_value_list(steps)
+
         return relax.op.strided_slice(
             data, axes, starts, ends, steps, assume_inbound=assume_inbound
         )
@@ -1730,9 +1745,21 @@ class Expand(OnnxOpConverter):
     def _impl_v13(cls, bb, inputs, attr, params):
         data = inputs[0]
         shape = inputs[1]
-
         if isinstance(shape, relax.ShapeExpr):
-            return relax.op.broadcast_to(data, shape)
+            data_shape = list(data.struct_info.shape)
+            target_shape = list(shape.values)
+            data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
+            assert len(data_shape) == len(target_shape)
+            # Fix small target shapes or target shapes assigned to -1
+            for i, s in enumerate(target_shape):
+                if isinstance(s, tvm.tir.IntImm) and (
+                    (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
+                    or s.value == -1
+                ):
+                    target_shape[i] = data_shape[i]
+            if target_shape == data_shape:
+                return data
+            return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
 
         # If possible, directly expand to constant shape.
         if isinstance(shape, relax.Constant):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 9faa441138..c130bf4373 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1507,10 +1507,6 @@ def test_topk(axis: int, largest: int):
 
 @pytest.mark.parametrize("dynamic", [False, True])
 def test_expand(dynamic):
-    if dynamic:
-        # TODO: Support dynamic shape for Expand
-        pytest.skip("Dynamic expand is not supported yet")
-
     def _test_expand(name, data, shape, ref_data):
         shape_array = np.array(shape)
         shape_node = onnx.helper.make_node(
@@ -1541,17 +1537,43 @@ def test_expand(dynamic):
         model = helper.make_model(graph, producer_name=name)
         check_correctness(model, inputs={"in": data})
 
-    in_shape = (3, 1)
-    shape = (3, 4)
-    data = np.random.uniform(size=in_shape).astype(np.float32)
-    ref_data = np.tile(data, 4)
-    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
-
-    in_shape = (3, 1)
-    shape = (1, 3, 4)
-    data = np.random.uniform(size=in_shape).astype(np.float32)
-    ref_data = np.tile(data, (1, 1, 4))
-    _test_expand("expand_with_diff_dim", data, shape, ref_data)
+    def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
+        shape_node = onnx.helper.make_node("Shape", inputs=["in_2"], outputs=["shape"])
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+        in_shape = list(data.shape)
+        out_shape = list(ref_data.shape)
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_test",
+            inputs=[
+                helper.make_tensor_value_info("in", TensorProto.FLOAT, in_shape),
+                helper.make_tensor_value_info("in_2", TensorProto.FLOAT, shape),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name=name)
+        check_correctness(model, inputs={"in": data, "in_2": shape_data})
+
+    if not dynamic:
+        in_shape = (3, 1)
+        shape = (3, 4)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, 4)
+        _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
+
+        in_shape = (3, 1)
+        shape = (1, 3, 4)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, (1, 1, 4))
+        _test_expand("expand_with_diff_dim", data, shape, ref_data)
+    else:
+        in_shape = (1, 32, 32)
+        shape = ("batch", 32, 32)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        shape_data = np.random.uniform(size=(64, 32, 32)).astype(np.float32)
+        ref_data = np.tile(data, (64, 1, 1))
+        _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
 
 
 # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.

Reply via email to