This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 15f9be5449 [TOPI] Expose `topi::collapse_sum` to Python and support
symbolic shape (#14541)
15f9be5449 is described below
commit 15f9be5449a1a351b65f327a3a5e2aaf578a7161
Author: Chaofan Lin <[email protected]>
AuthorDate: Sun Apr 9 20:13:57 2023 +0800
[TOPI] Expose `topi::collapse_sum` to Python and support symbolic shape
(#14541)
TOPI has an implementation of collapse_sum internally
(tvm/topi/reduction.h), but it is not exposed to the FFI and cannot be called
from the Python side. This patch exposes it and adds related tests. It also
lets the implementation of topi::collapse_sum support symbolic shape cases.
---
include/tvm/topi/reduction.h | 19 +++++++----
python/tvm/topi/reduction.py | 31 +++++++++++++++++
src/topi/reduction.cc | 4 +++
tests/python/topi/python/test_topi_reduce.py | 50 +++++++++++++++++++++++++++-
4 files changed, 97 insertions(+), 7 deletions(-)
diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index 169ae010aa..9090ae9707 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -333,21 +333,28 @@ inline Tensor sum(const Tensor& data, const
Array<Integer>& axis, bool keepdims
}
inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
- ICHECK_GE(data->shape.size(), target_shape.size());
- auto ishape = detail::GetConstIntValues(data->shape, "ishape");
- auto oshape = detail::GetConstIntValues(target_shape, "oshape");
+ const auto& ishape = data->shape;
+ const auto& oshape = target_shape;
+ int isize = data->shape.size();
+ int osize = target_shape.size();
+
+ ICHECK_GE(isize, osize)
+ << "Invalid collapse: input dimensionality smaller than output
dimensionality.\ninput shape: "
+ << data->shape << "\nvs\noutput shape: " << target_shape;
std::vector<int> reduce_axes;
std::vector<int> squeeze_axes;
- for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0;
--i_ax) {
- if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
+ tvm::PrimExpr one(1);
+
+ for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
+ if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
--o_ax;
continue;
}
reduce_axes.push_back(i_ax);
if (o_ax < 0) { // squeeze o_ax if was added during expansion
squeeze_axes.push_back(i_ax);
- } else if (oshape[o_ax] == 1) {
+ } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
--o_ax;
}
}
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 45d07af577..5045cb8174 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False):
ret : tvm.te.Tensor
"""
return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+ """Return a summation of data to the given shape.
+
+ collapse_sum is intended as the backward operator of topi broadcast
operators in the automatic
+ differentiation process.
+
+ We expect that data is the result of broadcasting some tensor of
target_shape in some
+ broadcast operation. Thus target_shape and data.shape must follow
broadcast rules.
+
+ During computation, the axes of data.shape and target_shape are checked
from right to left.
+ For every axis, if it either:
+ - exists in data but not in target_shape, or
+ - is larger than 1 in data and equal to 1 in target_shape,
+ data will be summed over this axis.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The input tensor.
+
+ target_shape : Tuple[int]
+ The shape to collapse to.
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ The result tensor after summation.
+ """
+ return cpp.collapse_sum(data, target_shape)
diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc
index 3d1c6f9f7d..a9d692cc07 100644
--- a/src/topi/reduction.cc
+++ b/src/topi/reduction.cc
@@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args,
TVMRetValue* rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});
+TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args,
TVMRetValue* rv) {
+ *rv = topi::collapse_sum(args[0], args[1]);
+});
+
} // namespace topi
} // namespace tvm
diff --git a/tests/python/topi/python/test_topi_reduce.py
b/tests/python/topi/python/test_topi_reduce.py
index 3c4c170d0d..8e45ae9a6e 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -25,7 +25,7 @@ import tvm
import tvm.testing
import tvm.topi.testing
-from tvm import te, topi
+from tvm import te, topi, tir
in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
((32,), 0, False, "argmax", "float32"),
@@ -191,5 +191,53 @@ def test_complex_reduce(target, dev):
tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
+n = tir.Var("n", "int32")
+m = tir.Var("m", "int32")
+true_value_map = {n: 3, m: 5}
+
+data_shape, target_shape = tvm.testing.parameters(
+ ((2, 3), (3,)),
+ ((2, 3, 4), (2, 1, 4)),
+ ((2, 3, 4, 5), (3, 1, 5)),
+ ((2, n, 4, m), (n, 1, m)),
+)
+
+
+def _my_npy_collapse_sum(data, target_shape):
+ reduce_axes = []
+ i = data.ndim - 1
+ j = len(target_shape) - 1
+ while i >= 0:
+ if j < 0:
+ reduce_axes.append(i)
+ elif target_shape[j] == 1 and data.shape[i] > 1:
+ reduce_axes.append(i)
+ i -= 1
+ j -= 1
+ return np.sum(data, tuple(reduce_axes)).reshape(target_shape)
+
+
+def test_collapse_sum(data_shape, target_shape):
+ A = te.placeholder(data_shape, name="A")
+ B = topi.collapse_sum(A, target_shape)
+ s = te.create_schedule([B.op])
+
+ data_shape_const = [int(s) if s not in true_value_map else
true_value_map[s] for s in A.shape]
+ target_shape_const = [
+ int(s) if s not in true_value_map else true_value_map[s] for s in
target_shape
+ ]
+ a_np = np.random.uniform(size=data_shape_const).astype(A.dtype)
+ b_np = _my_npy_collapse_sum(a_np, target_shape_const)
+ dev = tvm.cpu(0)
+ a = tvm.nd.array(a_np, dev)
+ B_shape_const = [int(s) if s not in true_value_map else true_value_map[s]
for s in B.shape]
+ b = tvm.nd.array(np.zeros(B_shape_const, dtype=B.dtype), dev)
+ # Building with the CSE pass disabled
+ with tvm.transform.PassContext(opt_level=3,
disabled_pass=["tir.CommonSubexprElimTIR"]):
+ foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")
+ foo(a, b)
+ tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()