This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cb6e4ee147 [Unity] Infer struct info for relax.op.split on
dynamic-sized index (#16355)
cb6e4ee147 is described below
commit cb6e4ee147b09dae057348143b8058c696951209
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Feb 6 08:36:55 2024 -0600
[Unity] Infer struct info for relax.op.split on dynamic-sized index (#16355)
Prior to this commit, the `relax.op.split` did not provide a known
shape if the split was performed over a dynamic-size axis. This
commit updates the shape inference to provide correct shapes in this
case. A test case is also added for `CombineParallelMatmul` to show
the intended usage of this feature, to ensure that the split outputs
from a combined matmul still have correct shape information.
---
src/relax/ir/block_builder.cc | 6 +-
src/relax/op/tensor/manipulate.cc | 50 +++++----
tests/python/relax/test_op_manipulate.py | 125 +++++++++++++++++++--
.../test_transform_combine_parallel_matmul.py | 47 ++++++++
4 files changed, 200 insertions(+), 28 deletions(-)
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index f74434bd74..b39beae740 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -170,13 +170,17 @@ class BlockBuilderImpl : public BlockBuilderNode {
auto it = shape_var_map.find(shape_var);
if (it == shape_var_map.end()) {
shape_var_map.Set(shape_var, shape_expr);
+ // Expose the shape variable as non-negative, for purposes
+      // of shape inference. In many cases, knowing that the
+ // shape variable is non-negative allows for simpler
+ // expressions for dynamic shapes.
+ analyzer_.MarkGlobalNonNegValue(shape_var);
} else {
const PrimExpr& old_shape_expr = (*it).second;
CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
<< "Inconsistent shape var " << shape_var << " in scope: " <<
old_shape_expr << " vs "
<< shape_expr;
}
- shape_var_map.Set(shape_var, shape_expr);
}
}
scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 12342aecf2..ad2a812c82 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -846,41 +846,48 @@ StructInfo InferStructInfoSplit(const Call& call, const
BlockBuilder& ctx) {
int axis =
data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx,
data_sinfo->ndim, attrs->axis);
- if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
+ if (auto opt_indices = attrs->indices_or_sections.as<Array<IntImm>>()) {
+ auto p_indices = opt_indices.value();
// When there is no index, return the input tensor's struct info.
- if (p_indices->size() == 0) {
+ if (p_indices.size() == 0) {
return TupleStructInfo({data_sinfo});
}
// Fall back to unknown shape when the input tensor doesn't have ShapeExpr
as shape.
if (data_shape == nullptr) {
return TupleStructInfo(Array<StructInfo>(
- p_indices->size() + 1,
+ p_indices.size() + 1,
TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
data_sinfo->vdevice)));
}
ICHECK_NE(axis, -1);
- const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
- // Fall back to unknown shape when the input tensor shape at the given
axis is symbolic.
- if (axis_length == nullptr) {
- return TupleStructInfo(Array<StructInfo>(
- p_indices->size() + 1,
- TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
data_sinfo->vdevice)));
- }
- // Only do output shape inference when all the indices and the total
length are integers.
- Array<IntImm> indices = GetRef<Array<IntImm>>(p_indices);
IntImm zero(DataType::Int(64), /*value=*/0);
- indices.insert(indices.begin(), zero);
- indices.insert(indices.end(), Downcast<IntImm>(data_shape->values[axis]));
std::vector<StructInfo> output_sinfo;
- output_sinfo.reserve(indices.size() - 1);
- for (int i = 0; i + 1 < static_cast<int>(indices.size()); ++i) {
- PrimExpr l = tvm::max(zero, indices[i]);
- PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]);
+ for (size_t i = 0; i < p_indices.size() + 1; i++) {
+ PrimExpr left;
+ if (i == 0) {
+ left = zero;
+ } else {
+ left = p_indices[i - 1];
+ }
+
+ PrimExpr right;
+ if (i < p_indices.size()) {
+ right = p_indices[i];
+ } else {
+ right = data_shape->values[axis];
+ }
+
+ left = tvm::min(tvm::max(left, 0), data_shape->values[axis]);
+ right = tvm::min(tvm::max(right, 0), data_shape->values[axis]);
+
+ PrimExpr split_dim = right - left;
+ split_dim = tvm::max(split_dim, 0);
+ split_dim = ctx->GetAnalyzer()->Simplify(split_dim);
Array<PrimExpr> shape = data_shape->values;
- shape.Set(axis, tvm::max(zero, r - l));
+ shape.Set(axis, split_dim);
output_sinfo.push_back(
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype,
data_sinfo->vdevice));
}
@@ -899,6 +906,7 @@ StructInfo InferStructInfoSplit(const Call& call, const
BlockBuilder& ctx) {
}
ICHECK_NE(axis, -1);
PrimExpr split_len = ceildiv(data_shape->values[axis], n_section);
+ split_len = ctx->GetAnalyzer()->Simplify(split_len);
// Construct struct info for tensors except the last one.
Array<PrimExpr> shape = data_shape->values;
@@ -907,7 +915,9 @@ StructInfo InferStructInfoSplit(const Call& call, const
BlockBuilder& ctx) {
n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype,
data_sinfo->vdevice));
// Construct struct info for the last tensor.
- shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
+ PrimExpr last_split_len = data_shape->values[axis] - split_len *
(n_section - 1);
+ last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len);
+ shape.Set(axis, last_split_len);
output_sinfo.push_back(
TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype,
data_sinfo->vdevice));
return TupleStructInfo(output_sinfo);
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 6c0fbcf227..ddb92725d4 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -20,7 +20,7 @@ import tvm.testing
from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op, VDevice
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
def test_op_correctness():
@@ -1832,9 +1832,9 @@ def
test_split_infer_struct_info_by_indices_shape_symbolic():
relax.op.split(x, [10, 20], axis=1),
relax.TupleStructInfo(
[
- relax.TensorStructInfo(dtype="float32", ndim=2),
- relax.TensorStructInfo(dtype="float32", ndim=2),
- relax.TensorStructInfo(dtype="float32", ndim=2),
+ relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b),
0)], dtype="float32"),
+ relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b),
0)], dtype="float32"),
+ relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"),
]
),
)
@@ -1987,9 +1987,9 @@ def
test_split_infer_struct_info_by_n_section_shape_symbolic():
relax.op.split(x, 3, axis=1),
relax.TupleStructInfo(
[
- relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
- relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
- relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2),
"float32"),
+ relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+ relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+ relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"),
]
),
)
@@ -2176,6 +2176,117 @@ def test_split_indices_or_sections_int64():
assert split1.attrs.indices_or_sections.dtype == "int64"
+def test_split_infer_struct_info():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ x = relax.Var("x", R.Tensor((16, 4)))
+ y = relax.Var("y", R.Tensor((16, 4), "float32"))
+ z = relax.Var("z", R.Tensor((n, 16)))
+ w = relax.Var("w", R.Tensor((n + 5, 16)))
+
+ # All relax shape variables are non-negative. When a scope
+ # begins, any TIR variables that are used as shape variables are
+    # declared to be non-negative in the `tvm.arith.Analyzer`.  Because
+ # `relax.op.split` clamps the indices to be within the bounds of
+ # the axis being split, simplifying with non-negative shape
+ # variables can result in much simpler shapes.
+ #
+ # For example, an axis of size `n`, split on the range from 2 to 5
+ # has size `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`. If it
+ # is known that `n >= 0`, then this simplifies down to `3`.
+ bb.begin_scope([x, y, z, w])
+
+ _check_inference(
+ bb,
+ relax.op.split(x, 1),
+ R.Tuple(
+ R.Tensor([16, 4]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(x, 2),
+ R.Tuple(
+ R.Tensor([8, 4]),
+ R.Tensor([8, 4]),
+ ),
+ )
+ # Uneven splits are allowed, with the last split being smaller than the
others.
+ _check_inference(
+ bb,
+ relax.op.split(x, 3),
+ R.Tuple(
+ R.Tensor([6, 4]),
+ R.Tensor([6, 4]),
+ R.Tensor([4, 4]),
+ ),
+ )
+
+ # Dtype of result is inherited from the tensor
+ _check_inference(
+ bb,
+ relax.op.split(y, 2),
+ R.Tuple(
+ R.Tensor([8, 4], "float32"),
+ R.Tensor([8, 4], "float32"),
+ ),
+ )
+
+ # Axis can be explicitly specified. Otherwise, defaults to axis=0.
+ _check_inference(
+ bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]),
R.Tensor([16, 2]))
+ )
+
+ # Split points can be explicitly specified
+ _check_inference(
+ bb,
+ relax.op.split(x, [2]),
+ R.Tuple(
+ R.Tensor([2, 4]),
+ R.Tensor([14, 4]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(x, [2, 5]),
+ R.Tuple(
+ R.Tensor([2, 4]),
+ R.Tensor([3, 4]),
+ R.Tensor([11, 4]),
+ ),
+ )
+
+ # Splitting a dynamic axis is allowed, and propagates the shape to the
output
+ _check_inference(
+ bb,
+ relax.op.split(z, 2),
+ R.Tuple(
+ R.Tensor([(n + 1) // 2, 16]),
+ R.Tensor([n - (n + 1) // 2, 16]),
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.split(z, 3),
+ R.Tuple(
+ R.Tensor([(n + 2) // 3, 16]),
+ R.Tensor([(n + 2) // 3, 16]),
+ R.Tensor([n - (n + 2) // 3 * 2, 16]),
+ ),
+ )
+
+ # Splitting a dynamic axis at specific indices is allowed.
+ _check_inference(
+ bb,
+ relax.op.split(w, [2, 5]),
+ R.Tuple(
+ R.Tensor((2, 16)),
+ R.Tensor((3, 16)),
+ R.Tensor((n, 16)),
+ ),
+ )
+
+
def test_split_infer_struct_info_non_integer_indices():
bb = relax.BlockBuilder()
a = tir.Var("c", "int64")
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py
b/tests/python/relax/test_transform_combine_parallel_matmul.py
index b06eddd2bb..7e7f2328f3 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,5 +525,52 @@ def test_check():
tvm.ir.assert_structural_equal(after, expected)
+def test_dynamic_rhs():
+ @R.function(private=True)
+ def before(
+ x: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, "M"), "float32"),
+ ):
+ M = T.int64()
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ out = (lv0, lv1)
+ R.output(out)
+ return out
+
+ @R.function(private=True)
+ def expected(
+ x: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, 640), dtype="float32"),
+ w1: R.Tensor((640, "M"), dtype="float32"),
+ ) -> R.Tuple(
+ R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, "M"),
dtype="float32")
+ ):
+ M = T.int64()
+ with R.dataflow():
+ lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w0, w1),
axis=1)
+ lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+ x, lv, out_dtype="float32"
+ )
+ lv2: R.Tuple(
+ R.Tensor((2, 1024, 640), dtype="float32"),
+ R.Tensor((2, 1024, M), dtype="float32"),
+ ) = R.split(lv1, indices_or_sections=[640], axis=2)
+ lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+ lv1_1: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+ out: R.Tuple(
+ R.Tensor((2, 1024, 640), dtype="float32"),
+ R.Tensor((2, 1024, M), dtype="float32"),
+ ) = (lv0, lv1_1)
+ R.output(out)
+ return out
+
+ after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()