This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit d4c8f9bb0b5b88c6a735e6ee00cea370e506ea99
Author: Yixin Dong <[email protected]>
AuthorDate: Sat Feb 18 21:14:34 2023 +0800

    [Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)
    
    This PR adds tests that specifically check the relax.reshape operator
    when the target shape argument is a ShapeExpr.
---
 tests/python/relax/test_op_manipulate.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 92d4bb2676..6c7727b7d5 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -54,6 +54,7 @@ def test_reshape_infer_struct_info():
     s0 = relax.Var("s", R.Shape((3, 8, 5)))
     s1 = relax.Var("s", R.Shape(ndim=3))
     s2 = relax.Var("s", R.Shape())
+    s3 = relax.ShapeExpr((3, 8, 5))
 
     _check_inference(
         bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
@@ -98,6 +99,12 @@ def test_reshape_infer_struct_info():
     _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype=""))
     _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype=""))
     _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype=""))
+    _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype=""))
+    _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype=""))
+    _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype=""))
 
 
 def test_reshape_infer_struct_info_shape_symbolic():
@@ -109,6 +116,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
     x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
     s0 = relax.Var("s", R.Shape((c, a, d, b)))
     s1 = relax.Var("s", R.Shape())
+    s2 = relax.ShapeExpr((c, a, d, b))
 
     _check_inference(
         bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32")
@@ -147,6 +155,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
     )
     _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32"))
     _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32"))
+    _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32"))
 
 
 def test_reshape_infer_struct_info_shape_var():

Reply via email to