This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new edb6b01b68 [Unity] Relax op: collapse sum (#14059)
edb6b01b68 is described below
commit edb6b01b683007f7cf702c85f6e4215365c7c646
Author: Chaosfan <[email protected]>
AuthorDate: Tue Feb 21 11:52:29 2023 +0800
[Unity] Relax op: collapse sum (#14059)
This PR brings high-level operators `relax.collapse_sum_like` and
`relax.collapse_sum_to` which are useful when doing AD in Relax. To achieve
this, it exposes the interface of `topi.collapse_sum`. Moreover, this PR also
implements the legalization of these ops and adds corresponding tests.
---
python/tvm/relax/op/manipulate.py | 53 ++++
.../tvm/relax/transform/legalize_ops/manipulate.py | 5 +
python/tvm/script/ir_builder/relax/ir.py | 4 +
python/tvm/topi/reduction.py | 31 ++
src/relax/op/tensor/manipulate.cc | 130 ++++++++
src/relax/op/tensor/manipulate.h | 21 ++
src/topi/reduction.cc | 4 +
tests/python/relax/test_op_manipulate.py | 326 +++++++++++++++++++++
.../test_transform_legalize_ops_manipulate.py | 103 +++++++
.../relax/test_tvmscript_parser_op_manipulate.py | 33 +++
tests/python/topi/python/test_topi_reduce.py | 39 +++
11 files changed, 749 insertions(+)
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index a46c62e1f1..25bf525191 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -261,3 +261,56 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]]
= None) -> Expr:
if isinstance(axis, int):
axis = [axis]
return _ffi_api.squeeze(x, axis) # type: ignore
+
+
+def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
+ """Return a summation of data to the shape of collapse_target.
+
+ For details, please see relax.op.collapse_sum_to.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor.
+
+ collapse_target : relax.Expr
+ The tensor whose shape is the shape to collapse to.
+
+ Returns
+ -------
+ result : relax.Expr
+ The result tensor after summation.
+ """
+ return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore
+
+
+def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) ->
Expr:
+ """Return a summation of data to the given shape.
+
+ collapse_sum_to is intended as the backward operator of
tvm.relax.op.broadcast_to and
+ other broadcast operators in the automatic differentiation process.
+
+ We expect that data is the result of broadcasting some tensor of the given
shape in some
+ broadcast operation. Thus the given `shape` and `data.shape` must follow
broadcast rules.
+
+ During computation, all axes of `data.shape` and `shape` are checked from
right to left.
+ For an axis, if it follows these rules, `data` will be summed over this
axis:
+ - the axis exists in `data.shape` but not in `shape`, or
+ - the axis exists in `data.shape` and equals 1 in `shape`.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor.
+
+ shape : Union[Tuple[PrimExprLike], relax.Expr]
+ The shape to collapse to.
+
+ Returns
+ -------
+ result : relax.Expr
+ The result tensor of the given shape after summation.
+ """
+ if isinstance(shape, (tuple, list)):
+ shape = ShapeExpr(shape)
+ return _ffi_api.collapse_sum_to(data, shape) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 76e3e74bab..5b992eff1d 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -37,6 +37,11 @@ def _reshape(
register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to,
"broadcast_to"))
register_legalize("relax.reshape", _reshape(topi.reshape, "reshape"))
+register_legalize(
+ "relax.collapse_sum_like",
+ _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
+)
+register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum,
"collapse_sum"))
@register_legalize("relax.concat")
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 7298b8c6e5..43918ce7ec 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,8 @@ from tvm.relax.op import (
call_tir,
ceil,
clip,
+ collapse_sum_like,
+ collapse_sum_to,
concat,
cos,
cosh,
@@ -485,6 +487,8 @@ __all__ = [
"call_builtin_with_ctx",
"ceil",
"clip",
+ "collapse_sum_like",
+ "collapse_sum_to",
"concat",
"cos",
"cosh",
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 45d07af577..5045cb8174 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False):
ret : tvm.te.Tensor
"""
return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+ """Return a summation of data to the given shape.
+
+ collapse_sum is intended as the backward operator of topi broadcast
operators in the automatic
+ differentiation process.
+
+ We expect that data is the result of broadcasting some tensor of
target_shape in some
+ broadcast operation. Thus target_shape and data.shape must follow
broadcast rules.
+
+ During computation, the axes of data.shape and target_shape are checked
from right to left.
+ For every axis, if it either:
+ - exists in data but not in target_shape, or
+ - is larger than 1 in data and equals 1 in target_shape,
+ data will be summed over this axis.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ The input tensor.
+
+ target_shape : Tuple[int]
+ The shape to collapse to.
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ The result tensor after summation.
+ """
+ return cpp.collapse_sum(data, target_shape)
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 8ce2a541da..e146a604af 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -839,5 +839,135 @@ TVM_REGISTER_OP("relax.squeeze")
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);
+void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
+ const Array<PrimExpr>& data_shape, const
Array<PrimExpr>& target_shape) {
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+
+ int data_ndim = data_shape.size();
+ int target_ndim = target_shape.size();
+
+ int data_ax = data_ndim - 1;
+ int target_ax = target_ndim - 1;
+ for (; data_ax >= 0; --data_ax) {
+ if (target_ax < 0) {
+ continue;
+ }
+ const PrimExpr& dim0 = data_shape[data_ax];
+ const PrimExpr& dim1 = target_shape[target_ax];
+ const auto* int_dim0 = dim0.as<IntImmNode>();
+ const auto* int_dim1 = dim1.as<IntImmNode>();
+
+ if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr &&
int_dim1->value == 1)) {
+ --target_ax;
+ } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "In " << call->op << ", the data shape at dim " <<
data_ax << " is "
+ << dim0 << " and the target shape at dim " << target_ax
<< " is " << dim1
+ << ", which do not match the rule of collapse sum.");
+ } else {
+ // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we
may need to revisit
+ // this requirement to reduce the workload of importers and better
support dynamic shapes.
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << call->op
+ << " fails to match the axes because of unknown dim or
symbolic"
+ " shape. In this position the dim of data shape is "
+ << dim0 << " while the dim of target shape is " << dim1
+ << ". If it is symbolic, consider use MatchCast
first.");
+ }
+ }
+}
+
+/* relax.collapse_sum_like */
+Expr collapse_sum_like(Expr data, Expr collapse_target) {
+ static const Op& op = Op::Get("relax.collapse_sum_like");
+ return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like);
+
+StructInfo InferStructInfoCollapseSumLike(const Call& call, const
BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo collapse_target_sinfo = input_sinfo[1];
+
+ DataType output_dtype = data_sinfo->dtype;
+
+ Optional<Array<PrimExpr>> data_shape_value;
+ if (data_sinfo->shape.defined()) {
+ data_shape_value =
GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
+ }
+ Optional<Array<PrimExpr>> collapse_target_shape_value;
+ if (collapse_target_sinfo->shape.defined()) {
+ collapse_target_shape_value =
+
GetStructInfoAs<ShapeStructInfoNode>(collapse_target_sinfo->shape.value())->values;
+ }
+
+ if (data_shape_value.defined() && collapse_target_shape_value.defined()) {
+ CheckCollapseShape(call, ctx, data_shape_value.value(),
collapse_target_shape_value.value());
+ }
+
+ if (collapse_target_sinfo->shape.defined()) {
+ return TensorStructInfo(collapse_target_sinfo->shape.value(),
output_dtype);
+ } else {
+ return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim);
+ }
+}
+
+TVM_REGISTER_OP("relax.collapse_sum_like")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("collapse_target", "Tensor",
+ "The tensor whose shape is the shape to collapse to.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCollapseSumLike);
+
+/* relax.collapse_sum_to */
+Expr collapse_sum_to(Expr data, Expr shape) {
+ static const Op& op = Op::Get("relax.collapse_sum_to");
+ return Call(op, {std::move(data), std::move(shape)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to);
+
+StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder&
ctx) {
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call) << "CollapseSumTo should have 2
arguments");
+ }
+
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* shape_sinfo =
GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "CollapseSumTo requires the input data to be a Tensor. However, the
given one is "
+ << call->args[0]->struct_info_->GetTypeKey());
+ }
+ if (shape_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "CollapseSumTo requires the input shape to be a Shape. However, the
given one is "
+ << call->args[1]->struct_info_->GetTypeKey());
+ }
+
+ DataType output_dtype = data_sinfo->dtype;
+
+ Optional<Array<PrimExpr>> data_shape_value;
+ if (data_sinfo->shape.defined()) {
+ data_shape_value =
GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
+ }
+
+ if (data_shape_value.defined() && shape_sinfo->values.defined()) {
+ CheckCollapseShape(call, ctx, data_shape_value.value(),
shape_sinfo->values.value());
+ }
+
+ return TensorStructInfo(/*shape=*/call->args[1], output_dtype);
+}
+
+TVM_REGISTER_OP("relax.collapse_sum_to")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("shape", "Shape", "The shape to collapse to.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCollapseSumTo);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 6a2b23ecbd..95e29a3dce 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -112,6 +112,27 @@ Expr split(Expr x, ObjectRef indices_or_sections, int
axis);
*/
Expr squeeze(Expr x, Optional<Array<Integer>> axis);
+/*!
+ * \brief Return a summation of data to the shape of collapse_target.
+ * For details, please see the operator `relax.collapse_sum_to`.
+ * \param data The input tensor.
+ * \param collapse_target The tensor whose shape is the shape to collapse to.
+ * \return The result tensor after summation.
+ */
+Expr collapse_sum_like(Expr data, Expr collapse_target);
+
+/*!
+ * \brief Return a summation of data to the given shape.
+ * collapse_sum_to is intended as the backward operator of broadcast_to and
+ * other broadcast operators in the automatic differentiation process.
+ * We expect that data is the result of broadcasting some tensor of the given
shape in some
+ * broadcast operation. Thus the given shape and data.shape must follow
broadcast rules.
+ * \param data The input tensor.
+ * \param shape The shape to collapse to.
+ * \return The result tensor of the given shape after summation.
+ */
+Expr collapse_sum_to(Expr data, Expr shape);
+
} // namespace relax
} // namespace tvm
diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc
index 3d1c6f9f7d..a9d692cc07 100644
--- a/src/topi/reduction.cc
+++ b/src/topi/reduction.cc
@@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args,
TVMRetValue* rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});
+TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args,
TVMRetValue* rv) {
+ *rv = topi::collapse_sum(args[0], args[1]);
+});
+
} // namespace topi
} // namespace tvm
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 6c7727b7d5..abb414b472 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -36,6 +36,9 @@ def test_op_correctness():
assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c,
a)).op == Op.get(
"relax.layout_transform"
)
+ assert relax.op.collapse_sum_to(x, (4, 5)).op ==
Op.get("relax.collapse_sum_to")
+ y = relax.Var("x", R.Tensor((4, 5), "float32"))
+ assert relax.op.collapse_sum_like(x, y).op ==
Op.get("relax.collapse_sum_like")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -2378,5 +2381,328 @@ def
test_broadcast_to_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.broadcast_to(x1, stgt))
+def test_collapse_sum_like_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor())
+ y0 = relax.Var("y", R.Tensor((3, 4), "float32"))
+ y1 = relax.Var("y", R.Tensor("float32", ndim=2))
+ y2 = relax.Var("y", R.Tensor("float32"))
+ y3 = relax.Var("y", R.Tensor((3, 4)))
+ y4 = relax.Var("y", R.Tensor(ndim=2))
+ y5 = relax.Var("y", R.Tensor((1, 4)))
+
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4),
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x1, y1),
relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y1),
relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y2),
relax.TensorStructInfo(dtype="float32", ndim=-1)
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4),
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4),
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x2, y4),
relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x4, y1),
relax.TensorStructInfo(dtype="", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4),
dtype="")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4),
"float32")
+ )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+ y0 = relax.Var("y", R.Tensor((4, a), "float32"))
+ x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+ y1 = relax.Var("x", R.Tensor((1, a + b), "float32"))
+
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a),
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a +
b), "float32")
+ )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4)))
+ s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+ s2 = relax.Var("s2", relax.ShapeStructInfo())
+ s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4)))
+ s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2))
+ s5 = relax.Var("s5", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+ y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+ y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32"))
+ y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32"))
+
+ _check_inference(bb, relax.op.collapse_sum_like(x0, y0),
relax.TensorStructInfo(s3, "float32"))
+ _check_inference(bb, relax.op.collapse_sum_like(x1, y1),
relax.TensorStructInfo(s4, "float32"))
+ _check_inference(bb, relax.op.collapse_sum_like(x2, y2),
relax.TensorStructInfo(s5, "float32"))
+
+
+def test_collapse_sum_like_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+ y0 = relax.Var("y", R.Tensor((3, 4), "float16"))
+ y1 = relax.Var("y", R.Tensor((3, 4), "int8"))
+
+ _check_inference(
+ bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4),
"float16")
+ )
+ _check_inference(bb, relax.op.collapse_sum_like(x1, y1),
relax.TensorStructInfo((3, 4), "int8"))
+
+
+def test_collapse_sum_like_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ x1 = relax.Var("x", relax.ShapeStructInfo((4, 5)))
+ x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x0, x1))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x2, x0))
+
+
+def test_collapse_sum_like_infer_struct_info_shape_mismatch():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32"))
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ x1 = relax.Var("z", R.Tensor((3, a, 5), "float32"))
+ y1 = relax.Var("w", R.Tensor((3, b, 5), "float32"))
+
+ s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5)))
+ s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5)))
+ x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32"))
+
+ s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5)))
+ s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5)))
+ x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+ y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x0, y0))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x1, y1))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x2, y2))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_like(x3, y3))
+
+
+def test_collapse_sum_to_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor())
+
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+ _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)),
relax.TensorStructInfo((3, 4), ""))
+ _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)),
relax.TensorStructInfo((3, 4), ""))
+ _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)),
relax.TensorStructInfo((3, 4), ""))
+
+
+def test_collapse_sum_to_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+ x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4,
a), "float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, (1, a + b)),
relax.TensorStructInfo((1, a + b), "float32")
+ )
+
+
+def test_collapse_sum_to_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4)))
+ s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+ s2 = relax.Var("s2", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3,
4), "float32")
+ )
+
+
+def test_collapse_sum_to_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3,
4), "float16")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3,
4), "int8")
+ )
+
+
+def test_collapse_sum_to_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ x1 = relax.Var("x", relax.ShapeStructInfo((4, 5)))
+ x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x0, x0))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x0, x2))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x1, x1))
+
+
+def test_collapse_sum_to_infer_struct_info_shape_mismatch():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ x1 = relax.Var("x", R.Tensor((3, a, 5), "float32"))
+
+ s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5)))
+ x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+
+ s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5)))
+ x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5)))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5)))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5)))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5)))
+
+
+def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ d = tir.Var("d", "int64")
+ s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b)))
+ s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+ s2 = relax.Var("s2", relax.ShapeStructInfo())
+ x0 = relax.Var("x", R.Tensor((3, a, b), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor(""))
+ x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+ stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b)))
+ stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2))
+ stgt2 = relax.Var("stgt2", relax.ShapeStructInfo())
+
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0,
"float32")
+ )
+ _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0),
relax.TensorStructInfo(stgt0, ""))
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1,
"float32")
+ )
+ _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1),
relax.TensorStructInfo(stgt1, ""))
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2,
"float32")
+ )
+ _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2),
relax.TensorStructInfo(stgt2, ""))
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2,
"float32")
+ )
+ _check_inference(
+ bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2,
"float32")
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 2a30994b83..8743261ee7 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -785,5 +785,108 @@ def test_squeeze_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_collapse_sum_like():
+ # fmt: off
+ @tvm.script.ir_module
+ class CollapseSumLike:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3),
"float32")) -> R.Tensor((1, 3), "float32"):
+ gv: R.Tensor((1, 3), "float32") = R.collapse_sum_like(x, y)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3),
"float32")) -> R.Tensor((1, 3), "float32"):
+ gv = R.call_tir(collapse_sum, (x,), R.Tensor((1, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def collapse_sum(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)),
"float32"], rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"]):
+ T.func_attr({"tir.noalias": True})
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)):
+ with T.block("rxplaceholder_red"):
+ ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(rxplaceholder[k0, ax1])
+ T.writes(rxplaceholder_red[ax0, ax1])
+ with T.init():
+ rxplaceholder_red[ax0, ax1] = T.float32(0)
+ rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1]
+ rxplaceholder[k0, ax1]
+ # fmt: on
+
+ mod = LegalizeOps()(CollapseSumLike)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI collapse_sum not support symbolic now")
+def test_collapse_sum_like_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class CollapseSumLike:
+ @R.function
+ def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b",
1), "float32")) -> R.Tensor(("b", 1), "float32"):
+ b = T.var("int64")
+ gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y)
+ return gv
+
+ # fmt: on
+
+ mod = LegalizeOps()(CollapseSumLike)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_collapse_sum_to():
+ # fmt: off
+ @tvm.script.ir_module
+ class CollapseSumTo:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1),
"float32"):
+ gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((3, 2, 3), dtype="float32")
+ ) -> R.Tensor((2, 1), dtype="float32"):
+ # block 0
+ gv = R.call_tir(collapse_sum, (x,), R.Tensor((2, 1),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def collapse_sum(rxplaceholder: T.Buffer[(T.int64(3), T.int64(2),
T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(2), T.int64(1)),
"float32"]):
+ T.func_attr({"tir.noalias": True})
+ for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3),
T.int64(3)):
+ with T.block("rxplaceholder_red"):
+ v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1,
k0, k2])
+ T.reads(rxplaceholder[v_k0, v_ax0, v_k2])
+ T.writes(rxplaceholder_red[v_ax0, v_ax1])
+ with T.init():
+ rxplaceholder_red[v_ax0, v_ax1] = T.float32(0)
+ rxplaceholder_red[v_ax0, v_ax1] =
(rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2])
+ # fmt: on
+
+ mod = LegalizeOps()(CollapseSumTo)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI collapse_sum not support symbolic now")
+def test_collapse_sum_to_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class CollapseSumTo:
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b",
1), "float32"):
+ b = T.var("int64")
+ gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1))
+ return gv
+
+ # fmt: on
+
+ mod = LegalizeOps()(CollapseSumTo)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 27f089ee67..c1d0c90d34 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -310,5 +310,38 @@ def test_squeeze_with_indices():
_check(foo, bb.get()["foo"])
+def test_collapse_sum_like():
+ @R.function
+ def foo(
+ x: R.Tensor((3, 4, 5), "float32"), y: R.Tensor((4, 5), "float32")
+ ) -> R.Tensor((4, 5), "float32"):
+ gv: R.Tensor((4, 5), "float32") = R.collapse_sum_like(x, y)
+ return gv
+
+ x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ y = relax.Var("y", R.Tensor((4, 5), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, y]):
+ gv = bb.emit(relax.op.collapse_sum_like(x, y))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_collapse_sum_to():
+ @R.function
+ def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((4, 5), "float32"):
+ gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(x, (4, 5))
+ return gv
+
+ x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.collapse_sum_to(x, (4, 5)))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_reduce.py
b/tests/python/topi/python/test_topi_reduce.py
index e7f47ba0c4..0f585fec96 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -26,6 +26,7 @@ import tvm.testing
import tvm.topi.testing
from tvm import te, topi
+from tvm.topi.utils import get_const_tuple
in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
((32,), 0, False, "argmax", "float32"),
@@ -183,5 +184,43 @@ def test_complex_reduce(target, dev):
tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
+data_shape, target_shape = tvm.testing.parameters(
+ ((2, 3), (3,)),
+ ((2, 3, 4), (2, 1, 4)),
+ ((2, 3, 4, 5), (3, 1, 5)),
+)
+
+
+def _my_npy_collapse_sum(data, target_shape):
+ reduce_axes = []
+ i = data.ndim - 1
+ j = len(target_shape) - 1
+ while i >= 0:
+ if j < 0:
+ reduce_axes.append(i)
+ elif target_shape[j] == 1 and data.shape[i] > 1:
+ reduce_axes.append(i)
+ i -= 1
+ j -= 1
+ return np.sum(data, tuple(reduce_axes)).reshape(target_shape)
+
+
+def test_collapse_sum(data_shape, target_shape):
+ A = te.placeholder(data_shape, name="A")
+ B = topi.collapse_sum(A, target_shape)
+ s = te.create_schedule([B.op])
+
+ a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+ b_np = _my_npy_collapse_sum(a_np, target_shape)
+ dev = tvm.cpu(0)
+ a = tvm.nd.array(a_np, dev)
+ b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+ # Building with the CSE pass disabled
+ with tvm.transform.PassContext(opt_level=3,
disabled_pass=["tir.CommonSubexprElimTIR"]):
+ foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")
+ foo(a, b)
+ tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()