This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new edb6b01b68 [Unity] Relax op: collapse sum (#14059)
edb6b01b68 is described below

commit edb6b01b683007f7cf702c85f6e4215365c7c646
Author: Chaosfan <[email protected]>
AuthorDate: Tue Feb 21 11:52:29 2023 +0800

    [Unity] Relax op: collapse sum (#14059)
    
    This PR brings high-level operators `relax.collapse_sum_like` and 
`relax.collapse_sum_to`, which are useful when doing AD in Relax. To achieve 
this, it exposes the interface of `topi.collapse_sum`. Moreover, this PR also 
implements the legalization of these ops and adds corresponding tests.
---
 python/tvm/relax/op/manipulate.py                  |  53 ++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |   5 +
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 python/tvm/topi/reduction.py                       |  31 ++
 src/relax/op/tensor/manipulate.cc                  | 130 ++++++++
 src/relax/op/tensor/manipulate.h                   |  21 ++
 src/topi/reduction.cc                              |   4 +
 tests/python/relax/test_op_manipulate.py           | 326 +++++++++++++++++++++
 .../test_transform_legalize_ops_manipulate.py      | 103 +++++++
 .../relax/test_tvmscript_parser_op_manipulate.py   |  33 +++
 tests/python/topi/python/test_topi_reduce.py       |  39 +++
 11 files changed, 749 insertions(+)

diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index a46c62e1f1..25bf525191 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -261,3 +261,56 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] 
= None) -> Expr:
     if isinstance(axis, int):
         axis = [axis]
     return _ffi_api.squeeze(x, axis)  # type: ignore
+
+
+def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
+    """Return a summation of data to the shape of collapse_target.
+
+    For details, please see relax.op.collapse_sum_to.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    collapse_target : relax.Expr
+        The tensor whose shape is the shape to collapse to.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor after summation.
+    """
+    return _ffi_api.collapse_sum_like(data, collapse_target)  # type: ignore
+
+
+def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> 
Expr:
+    """Return a summation of data to the given shape.
+
+    collapse_sum_to is intended as the backward operator of 
tvm.relax.op.broadcast_to and
+    other broadcast operators in the automatic differentiation process.
+
+    We expect that data is the result of broadcasting some tensor of the given 
shape in some
+    broadcast operation. Thus the given `shape` and `data.shape` must follow 
broadcast rules.
+
+    During computation, all axes of `data.shape` and `shape` are checked from 
right to left.
+    For an axis, if it follows these rules, `data` will be summed over this 
axis:
+    - the axis exists in `data.shape` but not in `shape`, or
+    - the axis exists in `data.shape` and equals to 1 in `shape`.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    shape : Union[Tuple[PrimExprLike], relax.Expr]
+        The shape to collapse to.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor of the given shape after summation.
+    """
+    if isinstance(shape, (tuple, list)):
+        shape = ShapeExpr(shape)
+    return _ffi_api.collapse_sum_to(data, shape)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 76e3e74bab..5b992eff1d 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -37,6 +37,11 @@ def _reshape(
 
 register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, 
"broadcast_to"))
 register_legalize("relax.reshape", _reshape(topi.reshape, "reshape"))
+register_legalize(
+    "relax.collapse_sum_like",
+    _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
+)
+register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, 
"collapse_sum"))
 
 
 @register_legalize("relax.concat")
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 7298b8c6e5..43918ce7ec 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,8 @@ from tvm.relax.op import (
     call_tir,
     ceil,
     clip,
+    collapse_sum_like,
+    collapse_sum_to,
     concat,
     cos,
     cosh,
@@ -485,6 +487,8 @@ __all__ = [
     "call_builtin_with_ctx",
     "ceil",
     "clip",
+    "collapse_sum_like",
+    "collapse_sum_to",
     "concat",
     "cos",
     "cosh",
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 45d07af577..5045cb8174 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False):
     ret : tvm.te.Tensor
     """
     return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+    """Return a summation of data to the given shape.
+
+    collapse_sum is intended as the backward operator of topi broadcast 
operators in the automatic
+    differentiation process.
+
+    We expect that data is the result of broadcasting some tensor of 
target_shape in some
+    broadcast operation. Thus target_shape and data.shape must follow 
broadcast rules.
+
+    During computation, the axes of data.shape and target_shape are checked 
from right to left.
+    For every axis, if it either:
+    - exist in data but not in target_shape, or
+    - is larger than 1 in data and equals to 1 in target_shape,
+    data will be summed over this axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor.
+
+    shape : Tuple[int]
+        The shape to collapse to.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor after summation.
+    """
+    return cpp.collapse_sum(data, target_shape)
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index 8ce2a541da..e146a604af 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -839,5 +839,135 @@ TVM_REGISTER_OP("relax.squeeze")
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);
 
+void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
+                        const Array<PrimExpr>& data_shape, const 
Array<PrimExpr>& target_shape) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+
+  int data_ndim = data_shape.size();
+  int target_ndim = target_shape.size();
+
+  int data_ax = data_ndim - 1;
+  int target_ax = target_ndim - 1;
+  for (; data_ax >= 0; --data_ax) {
+    if (target_ax < 0) {
+      continue;
+    }
+    const PrimExpr& dim0 = data_shape[data_ax];
+    const PrimExpr& dim1 = target_shape[target_ax];
+    const auto* int_dim0 = dim0.as<IntImmNode>();
+    const auto* int_dim1 = dim1.as<IntImmNode>();
+
+    if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr && 
int_dim1->value == 1)) {
+      --target_ax;
+    } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "In " << call->op << ", the data shape at dim " << 
data_ax << " is "
+                       << dim0 << " and the target shape at dim " << target_ax 
<< " is " << dim1
+                       << ", which do not match the rule of collapse sum.");
+    } else {
+      // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we 
may need to revisit
+      // this requirement to reduce the workload of importers and better 
support dynamic shapes.
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << call->op
+                       << " fails to match the axes because of unknown dim or 
symbolic"
+                          " shape. In this position the dim of data shape is "
+                       << dim0 << " while the dim of target shape is " << dim1
+                       << ". If it is symbolic, consider use MatchCast 
first.");
+    }
+  }
+}
+
+/* relax.collapse_sum_like */
+Expr collapse_sum_like(Expr data, Expr collapse_target) {
+  static const Op& op = Op::Get("relax.collapse_sum_like");
+  return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like);
+
+StructInfo InferStructInfoCollapseSumLike(const Call& call, const 
BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo collapse_target_sinfo = input_sinfo[1];
+
+  DataType output_dtype = data_sinfo->dtype;
+
+  Optional<Array<PrimExpr>> data_shape_value;
+  if (data_sinfo->shape.defined()) {
+    data_shape_value = 
GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
+  }
+  Optional<Array<PrimExpr>> collapse_target_shape_value;
+  if (collapse_target_sinfo->shape.defined()) {
+    collapse_target_shape_value =
+        
GetStructInfoAs<ShapeStructInfoNode>(collapse_target_sinfo->shape.value())->values;
+  }
+
+  if (data_shape_value.defined() && collapse_target_shape_value.defined()) {
+    CheckCollapseShape(call, ctx, data_shape_value.value(), 
collapse_target_shape_value.value());
+  }
+
+  if (collapse_target_sinfo->shape.defined()) {
+    return TensorStructInfo(collapse_target_sinfo->shape.value(), 
output_dtype);
+  } else {
+    return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim);
+  }
+}
+
+TVM_REGISTER_OP("relax.collapse_sum_like")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("collapse_target", "Tensor",
+                  "The tensor whose shape is the shape to collapse to.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCollapseSumLike);
+
+/* relax.collapse_sum_to */
+Expr collapse_sum_to(Expr data, Expr shape) {
+  static const Op& op = Op::Get("relax.collapse_sum_to");
+  return Call(op, {std::move(data), std::move(shape)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to);
+
+StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& 
ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "CollapseSumTo should have 2 
arguments");
+  }
+
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* shape_sinfo = 
GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "CollapseSumTo requires the input data to be a Tensor. However, the 
given one is "
+        << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (shape_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "CollapseSumTo requires the input shape to be a Shape. However, the 
given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  DataType output_dtype = data_sinfo->dtype;
+
+  Optional<Array<PrimExpr>> data_shape_value;
+  if (data_sinfo->shape.defined()) {
+    data_shape_value = 
GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
+  }
+
+  if (data_shape_value.defined() && shape_sinfo->values.defined()) {
+    CheckCollapseShape(call, ctx, data_shape_value.value(), 
shape_sinfo->values.value());
+  }
+
+  return TensorStructInfo(/*shape=*/call->args[1], output_dtype);
+}
+
+TVM_REGISTER_OP("relax.collapse_sum_to")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("shape", "Shape", "The shape to collapse to.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCollapseSumTo);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 6a2b23ecbd..95e29a3dce 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -112,6 +112,27 @@ Expr split(Expr x, ObjectRef indices_or_sections, int 
axis);
  */
 Expr squeeze(Expr x, Optional<Array<Integer>> axis);
 
+/*!
+ * \brief Return a summation of data to the shape of collapse_target.
+ * For details, please see the operator `relax.collapse_sum_to`.
+ * \param data The input tensor.
+ * \param collapse_target The tensor whose shape is the shape to collapse to.
+ * \return The result tensor after summation.
+ */
+Expr collapse_sum_like(Expr data, Expr collapse_target);
+
+/*!
+ * \brief Return a summation of data to the given shape.
+ * collapse_sum_to is intended as the backward operator of broadcast_to and
+ * other broadcast operators in the automatic differentiation process.
+ * We expect that data is the result of broadcasting some tensor of the given 
shape in some
+ * broadcast operation. Thus the given shape and data.shape must follow 
broadcast rules.
+ * \param data The input tensor.
+ * \param shape The shape to collapse to.
+ * \return The result tensor of the given shape after summation.
+ */
+Expr collapse_sum_to(Expr data, Expr shape);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc
index 3d1c6f9f7d..a9d692cc07 100644
--- a/src/topi/reduction.cc
+++ b/src/topi/reduction.cc
@@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, 
TVMRetValue* rv) {
   *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
 });
 
+TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, 
TVMRetValue* rv) {
+  *rv = topi::collapse_sum(args[0], args[1]);
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/relax/test_op_manipulate.py 
b/tests/python/relax/test_op_manipulate.py
index 6c7727b7d5..abb414b472 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -36,6 +36,9 @@ def test_op_correctness():
     assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, 
a)).op == Op.get(
         "relax.layout_transform"
     )
+    assert relax.op.collapse_sum_to(x, (4, 5)).op == 
Op.get("relax.collapse_sum_to")
+    y = relax.Var("x", R.Tensor((4, 5), "float32"))
+    assert relax.op.collapse_sum_like(x, y).op == 
Op.get("relax.collapse_sum_like")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -2378,5 +2381,328 @@ def 
test_broadcast_to_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.broadcast_to(x1, stgt))
 
 
+def test_collapse_sum_like_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+    y0 = relax.Var("y", R.Tensor((3, 4), "float32"))
+    y1 = relax.Var("y", R.Tensor("float32", ndim=2))
+    y2 = relax.Var("y", R.Tensor("float32"))
+    y3 = relax.Var("y", R.Tensor((3, 4)))
+    y4 = relax.Var("y", R.Tensor(ndim=2))
+    y5 = relax.Var("y", R.Tensor((1, 4)))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x1, y1), 
relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y1), 
relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y2), 
relax.TensorStructInfo(dtype="float32", ndim=-1)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x2, y4), 
relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x4, y1), 
relax.TensorStructInfo(dtype="", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), 
dtype="")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), 
"float32")
+    )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+    y0 = relax.Var("y", R.Tensor((4, a), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+    y1 = relax.Var("x", R.Tensor((1, a + b), "float32"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + 
b), "float32")
+    )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4)))
+    s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s2", relax.ShapeStructInfo())
+    s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4)))
+    s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2))
+    s5 = relax.Var("s5", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+    y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32"))
+
+    _check_inference(bb, relax.op.collapse_sum_like(x0, y0), 
relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.collapse_sum_like(x1, y1), 
relax.TensorStructInfo(s4, "float32"))
+    _check_inference(bb, relax.op.collapse_sum_like(x2, y2), 
relax.TensorStructInfo(s5, "float32"))
+
+
+def test_collapse_sum_like_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+    y0 = relax.Var("y", R.Tensor((3, 4), "float16"))
+    y1 = relax.Var("y", R.Tensor((3, 4), "int8"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), 
"float16")
+    )
+    _check_inference(bb, relax.op.collapse_sum_like(x1, y1), 
relax.TensorStructInfo((3, 4), "int8"))
+
+
+def test_collapse_sum_like_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((4, 5)))
+    x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x0, x1))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x2, x0))
+
+
+def test_collapse_sum_like_infer_struct_info_shape_mismatch():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32"))
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x1 = relax.Var("z", R.Tensor((3, a, 5), "float32"))
+    y1 = relax.Var("w", R.Tensor((3, b, 5), "float32"))
+
+    s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5)))
+    s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5)))
+    x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32"))
+
+    s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5)))
+    s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5)))
+    x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x0, y0))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x1, y1))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x2, y2))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x3, y3))
+
+
+def test_collapse_sum_to_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), 
relax.TensorStructInfo((3, 4), ""))
+    _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), 
relax.TensorStructInfo((3, 4), ""))
+    _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), 
relax.TensorStructInfo((3, 4), ""))
+
+
+def test_collapse_sum_to_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, 
a), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (1, a + b)), 
relax.TensorStructInfo((1, a + b), "float32")
+    )
+
+
+def test_collapse_sum_to_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4)))
+    s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s2", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 
4), "float32")
+    )
+
+
+def test_collapse_sum_to_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 
4), "float16")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 
4), "int8")
+    )
+
+
+def test_collapse_sum_to_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((4, 5)))
+    x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x0, x0))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x0, x2))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x1, x1))
+
+
+def test_collapse_sum_to_infer_struct_info_shape_mismatch():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x1 = relax.Var("x", R.Tensor((3, a, 5), "float32"))
+
+    s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5)))
+    x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+
+    s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5)))
+    x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5)))
+
+
+def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d = tir.Var("d", "int64")
+    s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b)))
+    s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s2", relax.ShapeStructInfo())
+    x0 = relax.Var("x", R.Tensor((3, a, b), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor(""))
+    x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b)))
+    stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2))
+    stgt2 = relax.Var("stgt2", relax.ShapeStructInfo())
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, 
"float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), 
relax.TensorStructInfo(stgt0, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, 
"float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), 
relax.TensorStructInfo(stgt1, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, 
"float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), 
relax.TensorStructInfo(stgt2, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, 
"float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, 
"float32")
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 2a30994b83..8743261ee7 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -785,5 +785,108 @@ def test_squeeze_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_collapse_sum_like():
+    # fmt: off
+    @tvm.script.ir_module
+    class CollapseSumLike:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), 
"float32")) -> R.Tensor((1, 3), "float32"):
+            gv: R.Tensor((1, 3), "float32") = R.collapse_sum_like(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), 
"float32")) -> R.Tensor((1, 3), "float32"):
+            gv = R.call_tir(collapse_sum, (x,), R.Tensor((1, 3), 
dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def collapse_sum(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), 
"float32"], rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"]):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)):
+                with T.block("rxplaceholder_red"):
+                    ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2])
+                    T.reads(rxplaceholder[k0, ax1])
+                    T.writes(rxplaceholder_red[ax0, ax1])
+                    with T.init():
+                        rxplaceholder_red[ax0, ax1] = T.float32(0)
+                    rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] 
+ rxplaceholder[k0, ax1]
+    # fmt: on
+
+    mod = LegalizeOps()(CollapseSumLike)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI collapse_sum not support symbolic now")
+def test_collapse_sum_like_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class CollapseSumLike:
+        @R.function
+        def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 
1), "float32")) -> R.Tensor(("b", 1), "float32"):
+            b = T.var("int64")
+            gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y)
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(CollapseSumLike)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_collapse_sum_to():
+    # fmt: off
+    @tvm.script.ir_module
+    class CollapseSumTo:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1), 
"float32"):
+            gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((3, 2, 3), dtype="float32")
+        ) -> R.Tensor((2, 1), dtype="float32"):
+            # block 0
+            gv = R.call_tir(collapse_sum, (x,), R.Tensor((2, 1), 
dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def collapse_sum(rxplaceholder: T.Buffer[(T.int64(3), T.int64(2), 
T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(2), T.int64(1)), 
"float32"]):
+            T.func_attr({"tir.noalias": True})
+            for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), 
T.int64(3)):
+                with T.block("rxplaceholder_red"):
+                    v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, 
k0, k2])
+                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2])
+                    T.writes(rxplaceholder_red[v_ax0, v_ax1])
+                    with T.init():
+                        rxplaceholder_red[v_ax0, v_ax1] = T.float32(0)
+                    rxplaceholder_red[v_ax0, v_ax1] = 
(rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2])
+    # fmt: on
+
+    mod = LegalizeOps()(CollapseSumTo)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI collapse_sum not support symbolic now")
+def test_collapse_sum_to_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class CollapseSumTo:
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b", 
1), "float32"):
+            b = T.var("int64")
+            gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(CollapseSumTo)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py 
b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 27f089ee67..c1d0c90d34 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -310,5 +310,38 @@ def test_squeeze_with_indices():
     _check(foo, bb.get()["foo"])
 
 
def test_collapse_sum_like():
    """Parsed ``R.collapse_sum_like`` must round-trip against the same
    function built directly through the BlockBuilder API."""

    @R.function
    def foo(
        x: R.Tensor((3, 4, 5), "float32"), y: R.Tensor((4, 5), "float32")
    ) -> R.Tensor((4, 5), "float32"):
        gv: R.Tensor((4, 5), "float32") = R.collapse_sum_like(x, y)
        return gv

    # Reference construction via the BlockBuilder.
    bb = relax.BlockBuilder()
    data = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
    target = relax.Var("y", R.Tensor((4, 5), "float32"))
    with bb.function("foo", [data, target]):
        out = bb.emit(relax.op.collapse_sum_like(data, target))
        bb.emit_func_output(out)

    _check(foo, bb.get()["foo"])
+
+
def test_collapse_sum_to():
    """Parsed ``R.collapse_sum_to`` must round-trip against the same
    function built directly through the BlockBuilder API."""

    @R.function
    def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((4, 5), "float32"):
        gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(x, (4, 5))
        return gv

    # Reference construction via the BlockBuilder.
    bb = relax.BlockBuilder()
    data = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
    with bb.function("foo", [data]):
        out = bb.emit(relax.op.collapse_sum_to(data, (4, 5)))
        bb.emit_func_output(out)

    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_reduce.py 
b/tests/python/topi/python/test_topi_reduce.py
index e7f47ba0c4..0f585fec96 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -26,6 +26,7 @@ import tvm.testing
 import tvm.topi.testing
 
 from tvm import te, topi
+from tvm.topi.utils import get_const_tuple
 
 in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
     ((32,), 0, False, "argmax", "float32"),
@@ -183,5 +184,43 @@ def test_complex_reduce(target, dev):
     tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
 
 
# (data_shape, target_shape) pairs for test_collapse_sum below.  Shapes are
# right-aligned: target dims of 1 mark broadcast axes to be summed away, and
# missing leading dims are reduced entirely.
data_shape, target_shape = tvm.testing.parameters(
    ((2, 3), (3,)),
    ((2, 3, 4), (2, 1, 4)),
    ((2, 3, 4, 5), (3, 1, 5)),
)
+
+
+def _my_npy_collapse_sum(data, target_shape):
+    reduce_axes = []
+    i = data.ndim - 1
+    j = len(target_shape) - 1
+    while i >= 0:
+        if j < 0:
+            reduce_axes.append(i)
+        elif target_shape[j] == 1 and data.shape[i] > 1:
+            reduce_axes.append(i)
+        i -= 1
+        j -= 1
+    return np.sum(data, tuple(reduce_axes)).reshape(target_shape)
+
+
def test_collapse_sum(data_shape, target_shape):
    """Check topi.collapse_sum on the LLVM target against the NumPy reference."""
    data = te.placeholder(data_shape, name="A")
    result = topi.collapse_sum(data, target_shape)
    schedule = te.create_schedule([result.op])

    rng_input = np.random.uniform(size=get_const_tuple(data.shape)).astype(data.dtype)
    expected = _my_npy_collapse_sum(rng_input, target_shape)

    dev = tvm.cpu(0)
    in_nd = tvm.nd.array(rng_input, dev)
    out_nd = tvm.nd.array(np.zeros(get_const_tuple(result.shape), dtype=result.dtype), dev)

    # Building with the CSE pass disabled
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
        func = tvm.build(schedule, [data, result], "llvm", name="collapse_sum")
    func(in_nd, out_nd)
    tvm.testing.assert_allclose(out_nd.numpy(), expected, rtol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to