This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb6e4ee147 [Unity] Infer struct info for relax.op.split on dynamic-sized index (#16355)
cb6e4ee147 is described below

commit cb6e4ee147b09dae057348143b8058c696951209
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Feb 6 08:36:55 2024 -0600

    [Unity] Infer struct info for relax.op.split on dynamic-sized index (#16355)
    
    Prior to this commit, `relax.op.split` did not provide a known
    shape if the split was performed over a dynamic-sized axis.  This
    commit updates the shape inference to provide correct shapes in this
    case.  A test case is also added for `CombineParallelMatmul` to show
    the intended usage of this feature, ensuring that the split outputs
    from a combined matmul still have correct shape information.
---
 src/relax/ir/block_builder.cc                      |   6 +-
 src/relax/op/tensor/manipulate.cc                  |  50 +++++----
 tests/python/relax/test_op_manipulate.py           | 125 +++++++++++++++++++--
 .../test_transform_combine_parallel_matmul.py      |  47 ++++++++
 4 files changed, 200 insertions(+), 28 deletions(-)
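
As context for the diff below, here is a minimal sketch (not part of the
commit) of the behavior this change enables: applying `relax.op.split`
to an axis whose extent is a symbolic `tir.Var` and still receiving
exact symbolic output shapes.  The sketch assumes the standard
`relax.BlockBuilder` Python API.

    import tvm
    from tvm import relax, tir
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    n = tir.Var("n", "int64")
    x = relax.Var("x", R.Tensor((n, 16), "float32"))

    with bb.function("main", [x]):
        # Within the function scope, `n` is a shape variable and is
        # therefore marked non-negative in the builder's analyzer.
        lv = bb.emit(relax.op.split(x, 2, axis=0))
        bb.emit_func_output(lv)

    # With this commit, the split outputs carry exact symbolic shapes,
    # e.g. ((n + 1) // 2, 16) and (n - (n + 1) // 2, 16), rather than
    # falling back to tensors of unknown shape.
    print(bb.get()["main"])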

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index f74434bd74..b39beae740 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -170,13 +170,17 @@ class BlockBuilderImpl : public BlockBuilderNode {
         auto it = shape_var_map.find(shape_var);
         if (it == shape_var_map.end()) {
           shape_var_map.Set(shape_var, shape_expr);
+          // Expose the shape variable as non-negative for purposes
+          // of shape inference.  In many cases, knowing that the
+          // shape variable is non-negative allows for simpler
+          // expressions for dynamic shapes.
+          analyzer_.MarkGlobalNonNegValue(shape_var);
         } else {
           const PrimExpr& old_shape_expr = (*it).second;
           CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr))
               << "Inconsistent shape var " << shape_var << " in scope: " << 
old_shape_expr << " vs "
               << shape_expr;
         }
-        shape_var_map.Set(shape_var, shape_expr);
       }
     }
     scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));
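
The effect of the `MarkGlobalNonNegValue` call added above can be
approximated from Python with the arithmetic analyzer.  This is a
sketch only: `MarkGlobalNonNegValue` is the C++ entry point, and the
`ConstIntBound` update below is assumed to have the equivalent effect
of declaring `n >= 0`.

    import tvm
    from tvm import tir

    analyzer = tvm.arith.Analyzer()
    n = tir.Var("n", "int64")

    def i64(value):
        return tir.IntImm("int64", value)

    # Extent of the middle piece when an axis of extent n + 5 is split
    # at indices [2, 5], in the clamped form produced by this commit.
    extent = tir.max(
        tir.min(i64(5), n + i64(5)) - tir.min(i64(2), n + i64(5)), i64(0)
    )

    # Without any bound on n, the expression stays symbolic.
    print(analyzer.simplify(extent))

    # Declaring n >= 0 lets the simplifier fold the extent to 3.
    analyzer.update(n, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF))
    print(analyzer.simplify(extent))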
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 12342aecf2..ad2a812c82 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -846,41 +846,48 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
   int axis =
      data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
 
-  if (const auto* p_indices = attrs->indices_or_sections.as<ArrayNode>()) {
+  if (auto opt_indices = attrs->indices_or_sections.as<Array<IntImm>>()) {
+    auto p_indices = opt_indices.value();
    // When there is no index, return the input tensor's struct info.
-    if (p_indices->size() == 0) {
+    if (p_indices.size() == 0) {
       return TupleStructInfo({data_sinfo});
     }
    // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape.
     if (data_shape == nullptr) {
       return TupleStructInfo(Array<StructInfo>(
-          p_indices->size() + 1,
+          p_indices.size() + 1,
          TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
     }
 
     ICHECK_NE(axis, -1);
-    const auto* axis_length = data_shape->values[axis].as<IntImmNode>();
-    // Fall back to unknown shape when the input tensor shape at the given axis is symbolic.
-    if (axis_length == nullptr) {
-      return TupleStructInfo(Array<StructInfo>(
-          p_indices->size() + 1,
-          TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)));
-    }
 
-    // Only do output shape inference when all the indices and the total length are integers.
-    Array<IntImm> indices = GetRef<Array<IntImm>>(p_indices);
     IntImm zero(DataType::Int(64), /*value=*/0);
-    indices.insert(indices.begin(), zero);
-    indices.insert(indices.end(), Downcast<IntImm>(data_shape->values[axis]));
 
     std::vector<StructInfo> output_sinfo;
-    output_sinfo.reserve(indices.size() - 1);
-    for (int i = 0; i + 1 < static_cast<int>(indices.size()); ++i) {
-      PrimExpr l = tvm::max(zero, indices[i]);
-      PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]);
+    for (size_t i = 0; i < p_indices.size() + 1; i++) {
+      PrimExpr left;
+      if (i == 0) {
+        left = zero;
+      } else {
+        left = p_indices[i - 1];
+      }
+
+      PrimExpr right;
+      if (i < p_indices.size()) {
+        right = p_indices[i];
+      } else {
+        right = data_shape->values[axis];
+      }
+
+      left = tvm::min(tvm::max(left, 0), data_shape->values[axis]);
+      right = tvm::min(tvm::max(right, 0), data_shape->values[axis]);
+
+      PrimExpr split_dim = right - left;
+      split_dim = tvm::max(split_dim, 0);
+      split_dim = ctx->GetAnalyzer()->Simplify(split_dim);
 
       Array<PrimExpr> shape = data_shape->values;
-      shape.Set(axis, tvm::max(zero, r - l));
+      shape.Set(axis, split_dim);
       output_sinfo.push_back(
          TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
     }
@@ -899,6 +906,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
     }
     ICHECK_NE(axis, -1);
     PrimExpr split_len = ceildiv(data_shape->values[axis], n_section);
+    split_len = ctx->GetAnalyzer()->Simplify(split_len);
 
     // Construct struct info for tensors except the last one.
     Array<PrimExpr> shape = data_shape->values;
@@ -907,7 +915,9 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) {
        n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
 
     // Construct struct info for the last tensor.
-    shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1));
+    PrimExpr last_split_len = data_shape->values[axis] - split_len * (n_section - 1);
+    last_split_len = ctx->GetAnalyzer()->Simplify(last_split_len);
+    shape.Set(axis, last_split_len);
     output_sinfo.push_back(
        TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice));
     return TupleStructInfo(output_sinfo);
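
In plain Python terms, the rewritten indices branch above computes each
output extent by clamping both endpoints into [0, axis_len] and taking
a non-negative difference.  A rough sketch with plain integers (the
real code operates on `PrimExpr` and then calls `Simplify`):

    def split_extents(axis_len, indices):
        # Sentinel bounds: 0 on the left, the axis extent on the right.
        bounds = [0] + list(indices) + [axis_len]
        extents = []
        for left, right in zip(bounds, bounds[1:]):
            # Clamp each endpoint into [0, axis_len], mirroring the
            # tvm::min / tvm::max calls in the C++ above.
            left = min(max(left, 0), axis_len)
            right = min(max(right, 0), axis_len)
            # The extent itself is also clamped to be non-negative.
            extents.append(max(right - left, 0))
        return extents

    assert split_extents(30, [10, 20]) == [10, 10, 10]
    # Out-of-range indices degrade gracefully to empty splits.
    assert split_extents(8, [10, 20]) == [8, 0, 0]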
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 6c0fbcf227..ddb92725d4 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -20,7 +20,7 @@ import tvm.testing
 from tvm import relax, tir
 from tvm import TVMError
 from tvm.ir import Op, VDevice
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
 
 
 def test_op_correctness():
@@ -1832,9 +1832,9 @@ def test_split_infer_struct_info_by_indices_shape_symbolic():
         relax.op.split(x, [10, 20], axis=1),
         relax.TupleStructInfo(
             [
-                relax.TensorStructInfo(dtype="float32", ndim=2),
-                relax.TensorStructInfo(dtype="float32", ndim=2),
-                relax.TensorStructInfo(dtype="float32", ndim=2),
+                relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"),
+                relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"),
+                relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"),
             ]
         ),
     )
@@ -1987,9 +1987,9 @@ def test_split_infer_struct_info_by_n_section_shape_symbolic():
         relax.op.split(x, 3, axis=1),
         relax.TupleStructInfo(
             [
-                relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
-                relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"),
-                relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"),
+                relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+                relax.TensorStructInfo((a, (b + 2) // 3), "float32"),
+                relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"),
             ]
         ),
     )
@@ -2176,6 +2176,117 @@ def test_split_indices_or_sections_int64():
     assert split1.attrs.indices_or_sections.dtype == "int64"
 
 
+def test_split_infer_struct_info():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((16, 4)))
+    y = relax.Var("y", R.Tensor((16, 4), "float32"))
+    z = relax.Var("z", R.Tensor((n, 16)))
+    w = relax.Var("w", R.Tensor((n + 5, 16)))
+
+    # All relax shape variables are non-negative.  When a scope
+    # begins, any TIR variables that are used as shape variables are
+    # declared to be non-negative in the `tvm.arith.Analyzer`.  Because
+    # `relax.op.split` clamps the indices to be within the bounds of
+    # the axis being split, simplifying with non-negative shape
+    # variables can result in much simpler shapes.
+    #
+    # For example, an axis of size `n + 5`, split on the range from 2
+    # to 5, has size `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`.  If
+    # it is known that `n >= 0`, then this simplifies down to `3`.
+    bb.begin_scope([x, y, z, w])
+
+    _check_inference(
+        bb,
+        relax.op.split(x, 1),
+        R.Tuple(
+            R.Tensor([16, 4]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(x, 2),
+        R.Tuple(
+            R.Tensor([8, 4]),
+            R.Tensor([8, 4]),
+        ),
+    )
+    # Uneven splits are allowed, with the last split being smaller than the others.
+    _check_inference(
+        bb,
+        relax.op.split(x, 3),
+        R.Tuple(
+            R.Tensor([6, 4]),
+            R.Tensor([6, 4]),
+            R.Tensor([4, 4]),
+        ),
+    )
+
+    # Dtype of result is inherited from the tensor
+    _check_inference(
+        bb,
+        relax.op.split(y, 2),
+        R.Tuple(
+            R.Tensor([8, 4], "float32"),
+            R.Tensor([8, 4], "float32"),
+        ),
+    )
+
+    # Axis can be explicitly specified.  Otherwise, defaults to axis=0.
+    _check_inference(
+        bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2]))
+    )
+
+    # Split points can be explicitly specified
+    _check_inference(
+        bb,
+        relax.op.split(x, [2]),
+        R.Tuple(
+            R.Tensor([2, 4]),
+            R.Tensor([14, 4]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(x, [2, 5]),
+        R.Tuple(
+            R.Tensor([2, 4]),
+            R.Tensor([3, 4]),
+            R.Tensor([11, 4]),
+        ),
+    )
+
+    # Splitting a dynamic axis is allowed, and propagates the shape to the output
+    _check_inference(
+        bb,
+        relax.op.split(z, 2),
+        R.Tuple(
+            R.Tensor([(n + 1) // 2, 16]),
+            R.Tensor([n - (n + 1) // 2, 16]),
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.split(z, 3),
+        R.Tuple(
+            R.Tensor([(n + 2) // 3, 16]),
+            R.Tensor([(n + 2) // 3, 16]),
+            R.Tensor([n - (n + 2) // 3 * 2, 16]),
+        ),
+    )
+
+    # Splitting a dynamic axis at specific indices is allowed.
+    _check_inference(
+        bb,
+        relax.op.split(w, [2, 5]),
+        R.Tuple(
+            R.Tensor((2, 16)),
+            R.Tensor((3, 16)),
+            R.Tensor((n, 16)),
+        ),
+    )
+
+
 def test_split_infer_struct_info_non_integer_indices():
     bb = relax.BlockBuilder()
     a = tir.Var("c", "int64")
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index b06eddd2bb..7e7f2328f3 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,5 +525,52 @@ def test_check():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_dynamic_rhs():
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, 640), "float32"),
+        w1: R.Tensor((640, "M"), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            out = (lv0, lv1)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, 640), dtype="float32"),
+        w1: R.Tensor((640, "M"), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, "M"), 
dtype="float32")
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w0, w1), axis=1)
+            lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+                x, lv, out_dtype="float32"
+            )
+            lv2: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = R.split(lv1, indices_or_sections=[640], axis=2)
+            lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+            lv1_1: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+            out: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = (lv0, lv1_1)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
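
For reference, a minimal sketch of driving the pass exercised by
`test_dynamic_rhs` over a whole module, assuming `CombineParallelMatmul`
is imported from `tvm.relax.transform` as in the rest of this test file:

    import tvm
    from tvm.relax.transform import CombineParallelMatmul

    def combine(mod: tvm.IRModule) -> tvm.IRModule:
        # With this commit, the R.split emitted for the combined matmul
        # keeps exact shapes even when one RHS has a dynamic dimension.
        return CombineParallelMatmul()(mod)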
