slyubomirsky commented on code in PR #16355:
URL: https://github.com/apache/tvm/pull/16355#discussion_r1473749036
##########
tests/python/relax/test_op_manipulate.py:
##########
@@ -2176,6 +2176,110 @@ def test_split_indices_or_sections_int64():
assert split1.attrs.indices_or_sections.dtype == "int64"
+def test_split_infer_struct_info():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ x = relax.Var("x", R.Tensor((16, 4)))
+ y = relax.Var("y", R.Tensor((16, 4), "float32"))
+ z = relax.Var("z", R.Tensor((n, 16)))
+ w = relax.Var("w", R.Tensor((n + 5, 16)))
+
+ _check_inference(
+ bb,
+ relax.op.split(x, 1),
+ R.Tuple(
+ R.Tensor([16, 4]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(x, 2),
+ R.Tuple(
+ R.Tensor([8, 4]),
+ R.Tensor([8, 4]),
+ ),
+ )
+ # Uneven splits are allowed, with the last split being smaller than the others.
+ _check_inference(
+ bb,
+ relax.op.split(x, 3),
+ R.Tuple(
+ R.Tensor([6, 4]),
+ R.Tensor([6, 4]),
+ R.Tensor([4, 4]),
+ ),
+ )
+
+ # Dtype of result is inherited from the tensor
+ _check_inference(
+ bb,
+ relax.op.split(y, 2),
+ R.Tuple(
+ R.Tensor([8, 4], "float32"),
+ R.Tensor([8, 4], "float32"),
+ ),
+ )
+
+ # Axis can be explicitly specified. Otherwise, defaults to axis=0.
+ _check_inference(
+ bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2]))
+ )
+
+ # Split points can be explicitly specified
+ _check_inference(
+ bb,
+ relax.op.split(x, [2]),
+ R.Tuple(
+ R.Tensor([2, 4]),
+ R.Tensor([14, 4]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(x, [2, 5]),
+ R.Tuple(
+ R.Tensor([2, 4]),
+ R.Tensor([3, 4]),
+ R.Tensor([11, 4]),
+ ),
+ )
+
+ # Splitting a dynamic axis is allowed, and propagates the shape to the output
+ _check_inference(
+ bb,
+ relax.op.split(z, 2),
+ R.Tuple(
+ R.Tensor([(n + 1) // 2, 16]),
+ R.Tensor([n - (n + 1) // 2, 16]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(z, 3),
+ R.Tuple(
+ R.Tensor([(n + 2) // 3, 16]),
+ R.Tensor([(n + 2) // 3, 16]),
+ R.Tensor([n - (n + 2) // 3 * 2, 16]),
+ ),
+ )
+
+ # Spliting a dynamic axis at specific indices is allowed. The
+ # algebraic form here isn't the cleanest, primarily because the
+ # test case doesn't know that `n` is a shape variable. When
+ # occurring in a relax function, `n` would be marked with
+ # `analyzer_.MarkGlobalNonNegValue`, which would make the shapes
+ # simplify to `[(2,16), (3,16), (n,16)]`.
Review Comment:
Also note the typo in the first word :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]