This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8b4dcfd1f1 [Unity][TOPI] Symbolic shape support for `collapse_sum` 
(#14535)
8b4dcfd1f1 is described below

commit 8b4dcfd1f13a7160cf186bf92ed22907c2589b0c
Author: Chaofan Lin <[email protected]>
AuthorDate: Tue Apr 11 12:00:39 2023 +0800

    [Unity][TOPI] Symbolic shape support for `collapse_sum` (#14535)
    
    This PR lets `topi::collapse_sum` support symbolic-shape cases. As a 
result, in high-level op legalization, we can now legalize 
`R.collapse_sum_like` / `R.collapse_sum_to` with symbolic shapes.
---
 include/tvm/topi/reduction.h                       | 19 +++++--
 .../test_transform_legalize_ops_manipulate.py      | 64 +++++++++++++++++++++-
 tests/python/topi/python/test_topi_reduce.py       | 18 ++++--
 3 files changed, 89 insertions(+), 12 deletions(-)

diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index 169ae010aa..9090ae9707 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -333,21 +333,28 @@ inline Tensor sum(const Tensor& data, const 
Array<Integer>& axis, bool keepdims
 }
 
 inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
-  ICHECK_GE(data->shape.size(), target_shape.size());
-  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
-  auto oshape = detail::GetConstIntValues(target_shape, "oshape");
+  const auto& ishape = data->shape;
+  const auto& oshape = target_shape;
+  int isize = data->shape.size();
+  int osize = target_shape.size();
+
+  ICHECK_GE(isize, osize)
+      << "Invalid collapse: input dimensionality smaller than output 
dimensionality.\ninput shape: "
+      << data->shape << "\nvs\noutput shape: " << target_shape;
 
   std::vector<int> reduce_axes;
   std::vector<int> squeeze_axes;
-  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; 
--i_ax) {
-    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
+  tvm::PrimExpr one(1);
+
+  for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
+    if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
       --o_ax;
       continue;
     }
     reduce_axes.push_back(i_ax);
     if (o_ax < 0) {  // squeeze o_ax if was added during expansion
       squeeze_axes.push_back(i_ax);
-    } else if (oshape[o_ax] == 1) {
+    } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
       --o_ax;
     }
   }
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 0f25f056b1..9be39183fd 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1013,7 +1013,6 @@ def test_collapse_sum_like():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
[email protected]("TOPI collapse_sum not support symbolic now")
 def test_collapse_sum_like_symbolic():
     # fmt: off
     @tvm.script.ir_module
@@ -1024,6 +1023,36 @@ def test_collapse_sum_like_symbolic():
             gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y)
             return gv
 
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def collapse_sum(var_rxplaceholder: T.handle, var_rxplaceholder_red: 
T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            a, b = T.int64(), T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, a))
+            rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (b, 
T.int64(1)))
+            # with T.block("root"):
+            for ax0, ax1, k0, k2 in T.grid(b, T.int64(1), a, a):
+                with T.block("rxplaceholder_red"):
+                    v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, 
k0, k2])
+                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2])
+                    T.writes(rxplaceholder_red[v_ax0, v_ax1])
+                    with T.init():
+                        rxplaceholder_red[v_ax0, v_ax1] = T.float32(0)
+                    rxplaceholder_red[v_ax0, v_ax1] = 
(rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2])
+
+        @R.function
+        def main(
+            x: R.Tensor(("a", "b", "a"), dtype="float32"),
+            y: R.Tensor(("b", 1), dtype="float32"),
+        ) -> R.Tensor(("b", 1), dtype="float32"):
+            b = T.int64()
+            a = T.int64()
+            cls = Expected
+            gv = R.call_tir(
+                cls.collapse_sum, (x,), out_sinfo=R.Tensor((b, 1), 
dtype="float32")
+            )
+            return gv
     # fmt: on
 
     mod = LegalizeOps()(CollapseSumLike)
@@ -1066,7 +1095,6 @@ def test_collapse_sum_to():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
[email protected]("TOPI collapse_sum not support symbolic now")
 def test_collapse_sum_to_symbolic():
     # fmt: off
     @tvm.script.ir_module
@@ -1077,6 +1105,38 @@ def test_collapse_sum_to_symbolic():
             gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1))
             return gv
 
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def collapse_sum(var_rxplaceholder: T.handle, var_rxplaceholder_red: 
T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            a, b, c = T.int64(), T.int64(), T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c))
+            rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, (b, 
T.int64(1)))
+            # with T.block("root"):
+            for ax0, ax1, k0, k2 in T.grid(b, T.int64(1), a, c):
+                with T.block("rxplaceholder_red"):
+                    v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, 
k0, k2])
+                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2])
+                    T.writes(rxplaceholder_red[v_ax0, v_ax1])
+                    with T.init():
+                        rxplaceholder_red[v_ax0, v_ax1] = T.float32(0)
+                    rxplaceholder_red[v_ax0, v_ax1] = (
+                        rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, 
v_ax0, v_k2]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor(("a", "b", "c"), dtype="float32")
+        ) -> R.Tensor(("b", 1), dtype="float32"):
+            b = T.int64()
+            a = T.int64()
+            c = T.int64()
+            cls = Expected
+            gv = R.call_tir(
+                cls.collapse_sum, (x,), out_sinfo=R.Tensor((b, 1), 
dtype="float32")
+            )
+            return gv
     # fmt: on
 
     mod = LegalizeOps()(CollapseSumTo)
diff --git a/tests/python/topi/python/test_topi_reduce.py 
b/tests/python/topi/python/test_topi_reduce.py
index 71ce654913..8f9e416ffb 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -25,7 +25,7 @@ import tvm
 import tvm.testing
 import tvm.topi.testing
 
-from tvm import te, topi
+from tvm import te, topi, tir
 from tvm.topi.utils import get_const_tuple
 
 in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
@@ -192,10 +192,15 @@ def test_complex_reduce(target, dev):
     tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
 
 
+n = tir.Var("n", "int32")
+m = tir.Var("m", "int32")
+true_value_map = {n: 3, m: 5}
+
 data_shape, target_shape = tvm.testing.parameters(
     ((2, 3), (3,)),
     ((2, 3, 4), (2, 1, 4)),
     ((2, 3, 4, 5), (3, 1, 5)),
+    ((2, n, 4, m), (n, 1, m)),
 )
 
 
@@ -218,11 +223,16 @@ def test_collapse_sum(data_shape, target_shape):
     B = topi.collapse_sum(A, target_shape)
     s = te.create_schedule([B.op])
 
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
-    b_np = _my_npy_collapse_sum(a_np, target_shape)
+    data_shape_const = [int(s) if s not in true_value_map else 
true_value_map[s] for s in A.shape]
+    target_shape_const = [
+        int(s) if s not in true_value_map else true_value_map[s] for s in 
target_shape
+    ]
+    a_np = np.random.uniform(size=data_shape_const).astype(A.dtype)
+    b_np = _my_npy_collapse_sum(a_np, target_shape_const)
     dev = tvm.cpu(0)
     a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    B_shape_const = [int(s) if s not in true_value_map else true_value_map[s] 
for s in B.shape]
+    b = tvm.nd.array(np.zeros(B_shape_const, dtype=B.dtype), dev)
     # Building with the CSE pass disabled
     with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["tir.CommonSubexprElimTIR"]):
         foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")

Reply via email to