This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f866ae3edd [Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)
f866ae3edd is described below
commit f866ae3edd9d77985612979e5c1936e93b56a740
Author: Yixin Dong <[email protected]>
AuthorDate: Sat Feb 18 21:14:34 2023 +0800
[Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)
This PR adds tests that specifically check the relax.reshape operator when
its target-shape argument is a ShapeExpr (rather than a shape Var or a
literal tuple).
---
tests/python/relax/test_op_manipulate.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 92d4bb2676..6c7727b7d5 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -54,6 +54,7 @@ def test_reshape_infer_struct_info():
s0 = relax.Var("s", R.Shape((3, 8, 5)))
s1 = relax.Var("s", R.Shape(ndim=3))
s2 = relax.Var("s", R.Shape())
+ s3 = relax.ShapeExpr((3, 8, 5))
_check_inference(
bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5),
"float32")
@@ -98,6 +99,12 @@ def test_reshape_infer_struct_info():
_check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2,
dtype=""))
_check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2,
dtype=""))
_check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2,
dtype=""))
+ _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3,
"float32"))
+ _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3,
"float32"))
+ _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3,
"float32"))
+ _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3,
dtype=""))
+ _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3,
dtype=""))
+ _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3,
dtype=""))
def test_reshape_infer_struct_info_shape_symbolic():
@@ -109,6 +116,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
s0 = relax.Var("s", R.Shape((c, a, d, b)))
s1 = relax.Var("s", R.Shape())
+ s2 = relax.ShapeExpr((c, a, d, b))
_check_inference(
bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a,
d, b), "float32")
@@ -147,6 +155,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
)
_check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0,
"float32"))
_check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1,
"float32"))
+ _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2,
"float32"))
def test_reshape_infer_struct_info_shape_var():