This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d4c8f9bb0b5b88c6a735e6ee00cea370e506ea99
Author: Yixin Dong <[email protected]>
AuthorDate: Sat Feb 18 21:14:34 2023 +0800

    [Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)

    This PR specially checks the relax.reshape operator when the input is a ShapeExpr.
---
 tests/python/relax/test_op_manipulate.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 92d4bb2676..6c7727b7d5 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -54,6 +54,7 @@ def test_reshape_infer_struct_info():
     s0 = relax.Var("s", R.Shape((3, 8, 5)))
     s1 = relax.Var("s", R.Shape(ndim=3))
     s2 = relax.Var("s", R.Shape())
+    s3 = relax.ShapeExpr((3, 8, 5))
 
     _check_inference(
         bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
@@ -98,6 +99,12 @@ def test_reshape_infer_struct_info():
     _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype=""))
     _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype=""))
     _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype=""))
+    _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype=""))
+    _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype=""))
+    _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype=""))
 
 
 def test_reshape_infer_struct_info_shape_symbolic():
@@ -109,6 +116,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
     x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
     s0 = relax.Var("s", R.Shape((c, a, d, b)))
     s1 = relax.Var("s", R.Shape())
+    s2 = relax.ShapeExpr((c, a, d, b))
 
     _check_inference(
         bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32")
@@ -147,6 +155,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
     )
     _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32"))
     _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32"))
+    _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32"))
 
 
 def test_reshape_infer_struct_info_shape_var():
