This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 675a22e080 [Unity][Transform] Introduce data-dependent operation of
reshape and its constant folding (#14282)
675a22e080 is described below
commit 675a22e080201c33217f5e35a305dc5a90878cf6
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Mar 21 13:33:56 2023 -0700
[Unity][Transform] Introduce data-dependent operation of reshape and its
constant folding (#14282)
* FEAT: Support data-dependent operation of reshape
* FEAT: Support constant folding with data-dependent reshape
* fix
* remove empty line
* reflect feedback
* Lift the lowering of tensor_to_shape from builtin to
DecomposeCompositeOps pass
* fix and comment
* fix
* add comments
* reflect feedback
* add comment
* fix
---
include/tvm/relax/transform.h | 7 +-
python/tvm/relax/op/base.py | 14 ++
.../tvm/relax/transform/legalize_ops/manipulate.py | 6 +-
python/tvm/relax/transform/transform.py | 11 +-
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/backend/vm/vm_builtin_lower.cc | 20 ++-
src/relax/op/op.cc | 26 +++
src/relax/op/tensor/manipulate.cc | 7 +-
...orm_inference.cc => decompose_composite_ops.cc} | 51 +++++-
src/relax/transform/fold_constant.cc | 70 +++++++-
src/runtime/relax_vm/builtin.cc | 34 ++++
tests/python/relax/test_op_manipulate.py | 25 +--
...y => test_transform_decompose_composite_ops.py} | 25 ++-
tests/python/relax/test_transform_fold_constant.py | 37 ++++
.../test_transform_legalize_ops_manipulate.py | 195 +++++++++++++++++++++
15 files changed, 493 insertions(+), 37 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index a3d0d4a0e9..4f45ba9c25 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -372,12 +372,13 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String,
ObjectRef>>> target_opt
Array<runtime::String> entry_functions);
/*!
- * \brief Simplify normalization operators during inference. For example, the
result
+ * \brief Decompose composite operators during inference. For example, the
result
* of a batch norm which is indexed at tuple index 0 will be unpacked into a
- * number of simplified operators.
+ * number of simplified operators. Operators like Attention, Erf, etc. can also
+ * be simplified into several operators.
* \return The Pass.
*/
-TVM_DLL Pass SimplifyNormInference();
+TVM_DLL Pass DecomposeCompositeOperator();
/*!
* \brief Returns a pass which replaces PrimFuncs which have matching
kOperatorName attribute in \p
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index aef0e731db..becd3f2a0f 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -402,3 +402,17 @@ def shape_of(expr: Expr) -> Expr:
A relax Call, which gets the shape of the input
"""
return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member
+
+
+def tensor_to_shape(expr: Expr) -> Expr:
+ """Convert tensor to shape expr.
+ Parameters
+ ----------
+ expr : Expr
+ The input Expr
+ Returns
+ -------
+ result : Expr
+ A relax Call, which transforms the tensor values to the shape
+ """
+ return _ffi_api.tensor_to_shape(expr) # type: ignore # pylint:
disable=no-member
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index e7cae1af34..144ef04748 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -23,7 +23,7 @@ import tvm
from tvm import topi, tir, relax, te
from tvm.tir.expr import IntImm
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr, Var, Tuple, TupleGetItem
+from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr
from .common import TEFunc, LegalizeFunc, register_legalize
@@ -32,6 +32,10 @@ def _reshape(
) -> LegalizeFunc:
def reshape_call_te(bb: BlockBuilder, call: Call):
tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like
else call.args[1]
+ # If target shape is Var, pass its bound expr only when it is ShapeExpr
+ if isinstance(tgt_shape, Var):
+ tgt_shape = bb.lookup_binding(tgt_shape)
+ assert isinstance(tgt_shape, ShapeExpr)
return bb.call_te(te_func, call.args[0], tgt_shape,
primfunc_name_hint=primfunc_name)
return reshape_call_te
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index e8e3d73113..c03df804ee 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -588,10 +588,11 @@ def MetaScheduleTuneIRMod(
return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global)
# type: ignore
-def SimplifyNormInference() -> tvm.ir.transform.Pass:
- """Simplify normalization operators during inference. For example, the
result
- of a batch norm which is indexed at tuple index 0 will be unpacked into a
- number of simplified operators.
+def DecomposeCompositeOps() -> tvm.ir.transform.Pass:
+ """Decompose composite operators that are composed of other operators during inference.
+ For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked
+ into a number of simplified operators. Attention, tensor_to_shape, etc. can also be
+ decomposed into a number of simplified operators.
Returns
-------
@@ -599,7 +600,7 @@ def SimplifyNormInference() -> tvm.ir.transform.Pass:
The registered pass
"""
- return _ffi_api.SimplifyNormInference() # type: ignore
+ return _ffi_api.DecomposeCompositeOps() # type: ignore
def AlterOpImpl(
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 32d6083e8a..ae0918a082 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -97,6 +97,7 @@ from tvm.relax.op import (
prod,
repeat,
reshape,
+ tensor_to_shape,
round,
shape_of,
std,
@@ -612,6 +613,7 @@ __all__ = [
"prod",
"repeat",
"reshape",
+ "tensor_to_shape",
"round",
"shape",
"shape_of",
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc
b/src/relax/backend/vm/vm_builtin_lower.cc
index 00d8512dc6..5bf4194997 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -129,9 +129,23 @@ class VMBuiltinLowerMutator : public ExprMutator {
Expr Reshape(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->struct_info_.defined());
- CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
- << "VMBuiltinLower expects the shape arg of reshape op to be a
ShapeExpr";
- return Call(builtin_reshape_, call_node->args, Attrs(),
{GetStructInfo(call_node)});
+ auto arg = call_node->args[1];
+ CHECK(arg->IsInstance<ShapeExprNode>() || arg->IsInstance<VarNode>())
+ << "VMBuiltinLower expects the shape arg of reshape op to be a
ShapeExpr or VarNode bound "
+ "to a ShapeExpr";
+
+ if (arg->IsInstance<ShapeExprNode>()) {
+ return Call(builtin_reshape_, call_node->args, Attrs(),
{GetStructInfo(call_node)});
+ } else {
+ // Handling the case when arg is VarNode
+ Optional<Expr> _bound_val = LookupBinding(Downcast<Var>(arg));
+ ICHECK(_bound_val.defined());
+ Expr bound_val = _bound_val.value();
+ CHECK(bound_val->IsInstance<ShapeExprNode>())
+ << "VMBuiltinLower expects bound value to be a ShapeExpr";
+ return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(),
+ {GetStructInfo(call_node)});
+ }
}
Expr ShapeOf(const Call& call_node) {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index a603040394..49df881dcb 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -315,6 +315,32 @@ Expr MakeShapeOf(Expr expr) {
TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf);
+// tensor_to_shape
+
+StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder&
ctx) {
+ ICHECK(call->args.size() == 1);
+ ICHECK(call->args[0]->struct_info_.defined());
+ const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tsinfo && tsinfo->shape.defined());
+ ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+ ICHECK(shape_expr->values.size() == 1);
+ const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+ ICHECK(ndim);
+ return ShapeStructInfo(ndim->value);
+}
+
+RELAY_REGISTER_OP("relax.tensor_to_shape")
+ .set_num_inputs(1)
+ .add_argument("input", "Expr", "The input expression")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
ReturnTensorToShapeStructInfo);
+
+Expr MakeTensorToShape(Expr expr) {
+ static const Op& op = Op::Get("relax.tensor_to_shape");
+ return Call(op, {expr}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape);
+
// alloc_tensor
StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder&
ctx) {
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index dbeb6f8d5b..faa5ee3bc0 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -751,7 +751,12 @@ StructInfo InferStructInfoReshape(const Call& call, const
BlockBuilder& ctx) {
<< new_shape_prod);
}
}
- return TensorStructInfo(call->args[1], data_sinfo->dtype);
+ Expr target_shape = call->args[1];
+ // If shape values are defined, use them
+ if (target_shape->IsInstance<VarNode>() &&
new_shape_sinfo->values.defined()) {
+ return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()),
data_sinfo->dtype);
+ }
+ return TensorStructInfo(target_shape, data_sinfo->dtype);
}
TVM_REGISTER_OP("relax.reshape")
diff --git a/src/relax/transform/simplify_norm_inference.cc
b/src/relax/transform/decompose_composite_ops.cc
similarity index 69%
rename from src/relax/transform/simplify_norm_inference.cc
rename to src/relax/transform/decompose_composite_ops.cc
index 545098db28..3681442221 100644
--- a/src/relax/transform/simplify_norm_inference.cc
+++ b/src/relax/transform/decompose_composite_ops.cc
@@ -21,6 +21,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include "utils.h"
@@ -110,21 +111,63 @@ class NormInferenceSimplifier : public ExprMutator {
Map<Expr, Expr> batch_norm_map_;
};
+class OpDecomposer : public ExprMutator {
+ public:
+ static Expr Decompose(Expr expr) { return OpDecomposer()(expr); }
+
+ private:
+ using ExprMutator::VisitExpr_;
+ Expr TensorToShape(const Call& call_node) {
+ ICHECK(call_node->struct_info_.defined());
+ Expr expr = call_node->args[0];
+ const ShapeStructInfoNode* sinfo =
GetStructInfoAs<ShapeStructInfoNode>(call_node);
+ ICHECK(sinfo);
+ // call builtin function that converts tensor to shape tuple
+ // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape"
+ Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"),
{expr}, {},
+ {GetRef<ShapeStructInfo>(sinfo)}));
+
+ // Operators like reshape take the output of `TensorToShape` as their output shape.
+ // Because TOPI expects such an output shape to be at least a symbolic shape (i.e.,
+ // Array<PrimExpr>), we define symbolic variables and return them as a ShapeExpr.
+ Array<PrimExpr> shape_var;
+ for (int i = 0; i < sinfo->ndim; i++) {
+ shape_var.push_back(tir::Var("x", DataType::Int(64)));
+ }
+ // bind symbolic variables to the shape tuple
+ relax::Var var("y", ShapeStructInfo(shape_var));
+ builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
+ return ShapeExpr(shape_var);
+ }
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+ if (call->op == tensor_to_shape_op_) {
+ return TensorToShape(call);
+ } else {
+ return call;
+ }
+ }
+
+ const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+};
+
namespace transform {
-Pass SimplifyNormInference() {
+Pass DecomposeCompositeOps() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
[=](Function f, IRModule m, PassContext pc) {
f = Downcast<Function>(NormInferenceSimplifier::Simplify(f));
- // Remove original batch_norm op if it's not used.
+ f = Downcast<Function>(OpDecomposer::Decompose(f));
+ // Remove the original ops if they are not used.
return RemoveAllUnused(f);
};
return CreateFunctionPass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
- /*pass_name=*/"SimplifyNormInference", //
+ /*pass_name=*/"DecomposeCompositeOps", //
/*required=*/{});
}
-TVM_REGISTER_GLOBAL("relax.transform.SimplifyNormInference").set_body_typed(SimplifyNormInference);
+TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps);
} // namespace transform
} // namespace relax
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index 622dd9ad09..315b3dc1f2 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -196,6 +196,10 @@ class ConstantFolder : public ExprMutator {
using ExprMutator::VisitExpr_;
+ // TODO(@sunggg):
+ // Next PR will support fold with PackedFunc and MatchCast
+ // Until then, DecomposeCompositeOps() should be applied after
+ // this pass to fold `tensor_to_shape` op.
Expr VisitExpr_(const CallNode* call) final {
// post-order mutation
Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
@@ -217,14 +221,64 @@ class ConstantFolder : public ExprMutator {
return VisitCallTIR(post_call).value_or(post_call);
}
- // If we are in a dataflow block, we can fold ops by lowering them to
call_tir.
- if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) {
- // Get the legalized expression
- Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_,
post_call));
- // If the legalized expression is call_tir, try to fold it.
- const CallNode* call = legalized_expr.as<CallNode>();
- if (call && call->op.same_as(call_tir_op)) {
- return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+ // Special logic to fold ShapeExpr between operators
+ // e.g.,
+ // <Before>
+ // lv: R.Shape([16, 16]) = R.shape([16, 16])
+ // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, lv)
+ // <After>
+ // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16,
16]))
+ //
+ Array<Expr> new_args;
+ for (auto arg : post_call->args) {
+ if (arg->IsInstance<VarNode>()) {
+ Optional<Expr> val = LookupBinding(Downcast<Var>(arg));
+ if (val.defined() && val.value()->IsInstance<ShapeExprNode>()) {
+ new_args.push_back(val.value());
+ continue;
+ }
+ }
+ new_args.push_back(arg);
+ }
+ post_call =
+ Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args,
post_call->span);
+
+ // If we are in a dataflow block, we can fold ops.
+ if (builder_->CurrentBlockIsDataFlow()) {
+ // Check if we can lower the op to call_tir
+ if (legalize_map.count(op)) {
+ // Get the legalized expression
+ Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_,
post_call));
+ // If the legalized expression is call_tir, try to fold it.
+ const CallNode* call = legalized_expr.as<CallNode>();
+ if (call && call->op.same_as(call_tir_op)) {
+ return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+ }
+ } else if (op->name == "relax.tensor_to_shape") {
+ // Special handling for composite op "relax.tensor_to_shape"
+ // If its input is constant, we can access its value and create
ShapeExpr
+ // TODO(@sunggg):
+ // currently, we do not have an info map about decomposition.
+ // Thus, this is a temporary solution until we have a consensus about
+ // how to deal with composite ops. One possibility is to register a
+ // decomposition map for each op, similar to what we do for legalization.
+ ICHECK_EQ(post_call->args.size(), 1);
+ Expr arg = post_call->args[0];
+ if (arg->IsInstance<ConstantNode>()) {
+ Constant constant = Downcast<Constant>(arg);
+ runtime::NDArray ndarray = constant->data;
+ ICHECK_EQ(ndarray->device.device_type, kDLCPU);
+ ICHECK(ndarray->strides == nullptr);
+ ICHECK_EQ(ndarray->byte_offset, 0);
+ ICHECK_EQ(ndarray->ndim, 1);
+ const int64_t* data = static_cast<const int64_t*>(ndarray->data);
+ int64_t num_elems = ndarray->shape[0];
+ Array<PrimExpr> shape_values;
+ for (int64_t i = 0; i < num_elems; i++) {
+ shape_values.push_back(IntImm(DataType::Int(64), data[i]));
+ }
+ return ShapeExpr(shape_values);
+ }
}
}
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 15a4f8702b..5a7c1d6620 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -380,6 +380,40 @@
TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetVal
*rv = arr;
});
+TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray
data) {
+ NDArray arr = data;
+ if (data->device.device_type != kDLCPU) {
+ arr = data.CopyTo(DLDevice{kDLCPU, 0});
+ }
+
+ ICHECK_EQ(arr->ndim, 1);
+ ICHECK_EQ(arr->dtype.code, kDLInt);
+
+ std::vector<int64_t> out_shape;
+ for (int i = 0; i < arr.Shape()[0]; ++i) {
+ int64_t result;
+ switch (arr->dtype.bits) {
+ case 16: {
+ result = reinterpret_cast<int16_t*>(arr->data)[i];
+ break;
+ }
+ case 32: {
+ result = reinterpret_cast<int32_t*>(arr->data)[i];
+ break;
+ }
+ case 64: {
+ result = reinterpret_cast<int64_t*>(arr->data)[i];
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown scalar int type: " <<
DLDataType2String(arr->dtype);
+ throw;
+ }
+ out_shape.push_back(result);
+ }
+ return ShapeTuple(out_shape);
+});
+
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index 16bbc04d26..3edf63764a 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -88,12 +88,13 @@ def test_reshape_infer_struct_info():
_check_inference(
bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5),
dtype="")
)
- _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo(s0,
"float32"))
- _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo(s0,
"float32"))
- _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo(s0,
"float32"))
- _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo(s0,
dtype=""))
- _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo(s0,
dtype=""))
- _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo(s0,
dtype=""))
+ # Remove Var from StructInfo when we can
+ _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3,
8, 5), "float32"))
+ _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3,
8, 5), "float32"))
+ _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3,
8, 5), "float32"))
+ _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3,
8, 5), dtype=""))
+ _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3,
8, 5), dtype=""))
+ _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3,
8, 5), dtype=""))
_check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1,
"float32"))
_check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1,
"float32"))
_check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1,
"float32"))
@@ -160,7 +161,8 @@ def test_reshape_infer_struct_info_shape_symbolic():
(c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))),
"float32"
),
)
- _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0,
"float32"))
+ # Remove Var from StructInfo when we can
+ _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c,
a, d, b), "float32"))
_check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1,
"float32"))
_check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2,
"float32"))
@@ -188,17 +190,20 @@ def test_reshape_infer_struct_info_shape_var():
_check_inference(
bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8,
5), "float32")
)
- _check_inference(bb, relax.op.reshape(x0, ns0),
relax.TensorStructInfo(ns0, "float32"))
+ # Remove Var from StructInfo when we can
+ _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3,
8, 5), "float32"))
_check_inference(bb, relax.op.reshape(x0, ns1),
relax.TensorStructInfo(ns1, "float32"))
_check_inference(
bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5),
"float32")
)
- _check_inference(bb, relax.op.reshape(x1, ns0),
relax.TensorStructInfo(ns0, "float32"))
+ # Remove Var from StructInfo when we can
+ _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3,
8, 5), "float32"))
_check_inference(bb, relax.op.reshape(x1, ns1),
relax.TensorStructInfo(ns1, "float32"))
_check_inference(
bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5),
"float32")
)
- _check_inference(bb, relax.op.reshape(x2, ns0),
relax.TensorStructInfo(ns0, "float32"))
+ # Remove Var from StructInfo when we can
+ _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3,
8, 5), "float32"))
_check_inference(bb, relax.op.reshape(x2, ns1),
relax.TensorStructInfo(ns1, "float32"))
diff --git a/tests/python/relax/test_transform_simpilify_norm_inference.py
b/tests/python/relax/test_transform_decompose_composite_ops.py
similarity index 86%
rename from tests/python/relax/test_transform_simpilify_norm_inference.py
rename to tests/python/relax/test_transform_decompose_composite_ops.py
index 3c981ba035..08483600a3 100644
--- a/tests/python/relax/test_transform_simpilify_norm_inference.py
+++ b/tests/python/relax/test_transform_decompose_composite_ops.py
@@ -22,7 +22,7 @@ import tvm.script
import tvm.testing
from tvm import IRModule, relax
from tvm.relax import Function
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
def _check(before: Union[Function, IRModule], expected: Union[Function,
IRModule]):
@@ -30,7 +30,7 @@ def _check(before: Union[Function, IRModule], expected:
Union[Function, IRModule
before = IRModule({"main": before})
if isinstance(expected, Function):
expected = IRModule({"main": expected})
- after = relax.transform.SimplifyNormInference()(before)
+ after = relax.transform.DecomposeCompositeOps()(before)
tvm.ir.assert_structural_equal(expected, after)
@@ -149,5 +149,26 @@ def test_batch_norm_complex():
_check(before, expected)
+def test_op_tensor_to_shape():
+ @R.function
+ def before(t: R.Tensor(ndim=1, dtype="int64")):
+ gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
+ return gv
+
+ @R.function
+ def expected(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
+ x = T.int64()
+ x_1 = T.int64()
+ x_2 = T.int64()
+ gv: R.Shape(ndim=3) = R.call_packed(
+ "vm.builtin.tensor_to_shape", t, sinfo_args=(R.Shape(ndim=3),)
+ )
+ y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2]))
+ gv_1: R.Shape([x, x_1, x_2]) = R.shape([x, x_1, x_2])
+ return gv_1
+
+ _check(before, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
index 5bf2d3d9ab..ebd4348b64 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -349,6 +349,43 @@ def test_do_not_fold_ops_outside_dataflow():
tvm.ir.assert_structural_equal(after, before)
+def test_fold_multiple_relax_ops_with_data_dependent_reshape():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def before(
+ data: R.Tensor((256,), "float32"),
+ c0: R.Tensor((2,), "int64"),
+ c1: R.Tensor((2,), "int64"),
+ ):
+ with R.dataflow():
+ lv0 = R.add(c0, c0)
+ target_shape = R.multiply(lv0, c1)
+ lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
+ gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16),
dtype="float32"):
+ R.func_attr({"global_symbol": "main"})
+ with R.dataflow():
+ gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data,
R.shape([16, 16]))
+ R.output(gv)
+ return gv
+
+ c0_np = [8, 8]
+ c1_np = [1, 1]
+ before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
+ assert relax.analysis.well_formed(before)
+
+ c2_np = np.multiply(np.add(c0_np, c0_np), c1_np)
+ expected = gen_mod(Module, "expected", {"c2": c2_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
def test_unsupported_fold_ops_legalized_to_multiple_calls():
@tvm.script.ir_module
class Module:
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index b50ba91089..cce35a9026 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -17,6 +17,7 @@
import pytest
import tvm
+from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
@@ -498,6 +499,55 @@ def test_reshape():
mod = LegalizeOps()(Reshape)
tvm.ir.assert_structural_equal(mod, Expected)
+ # fmt: off
+ # ShapeExpr might be produced by shape computation
+ @tvm.script.ir_module
+ class Reshape2:
+ @R.function
+ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3),
"float32"):
+ lv: R.Shape((8, 3)) = R.shape((8, 3))
+ gv: R.Tensor((8, 3), "float32") = R.reshape(x, lv)
+ return gv
+
+ # After lowering, redundant var might be removed by later dead code
elimination
+ @tvm.script.ir_module
+ class Expected2:
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3),
T.int64(4)), "float32"),
+ T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(8), T.int64(3)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(
+ rxplaceholder[
+ T.int64(0),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) //
T.int64(12),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) //
T.int64(4),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(4),
+ ]
+ )
+ T.writes(T_reshape[v_ax0, v_ax1])
+ T_reshape[v_ax0, v_ax1] = rxplaceholder[
+ T.int64(0),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) //
T.int64(12),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) //
T.int64(4),
+ (v_ax0 * T.int64(3) + v_ax1) % T.int64(4),
+ ]
+
+ @R.function
+ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8,
3), dtype="float32"):
+ lv: R.Shape((8, 3)) = R.shape((8, 3))
+ gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8,
3), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod2 = LegalizeOps()(Reshape2)
+ tvm.ir.assert_structural_equal(mod2, Expected2)
+
def test_reshape_symbolic():
# fmt: off
@@ -537,6 +587,151 @@ def test_reshape_symbolic():
mod = LegalizeOps()(Reshape)
tvm.ir.assert_structural_equal(mod, Expected)
+ # ShapeExpr might be produced by shape computation
+ @tvm.script.ir_module
+ class Reshape2:
+ @R.function
+ def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b
* 2"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2))
+ gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, lv)
+ return gv
+
+ # After lowering, redundant var might be removed by later dead code
elimination
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b
* 2"), "float32"):
+ a = T.int64()
+ b = T.int64()
+ lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2))
+ gv = R.call_tir(Expected2.reshape, (x,), R.Tensor(((a // 2), (b *
2)), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a = T.int64()
+ b = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b],
dtype="float32")
+ T_reshape = T.match_buffer(
+ var_T_reshape, [a // T.int64(2), b * T.int64(2)],
dtype="float32"
+ )
+ for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)):
+ with T.block("T_reshape"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(
+ rxplaceholder[
+ (ax0 * (b * T.int64(2)) + ax1) // b % a,
+ (ax0 * (b * T.int64(2)) + ax1) % b,
+ ]
+ )
+ T.writes(T_reshape[ax0, ax1])
+ T_reshape[ax0, ax1] = rxplaceholder[
+ (ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b *
T.int64(2)) + ax1) % b
+ ]
+
+ mod2 = LegalizeOps()(Reshape2)
+ tvm.ir.assert_structural_equal(mod2, Expected2)
+
+ # ShapeExpr might be produced by shape computation
+ @I.ir_module
+ class Reshape3:
+ @R.function
+ def main(x: R.Tensor((10, "b"), "float32")) -> R.Tensor((5, "b * 2"),
"float32"):
+ a = T.int64()
+ b = T.int64()
+ lv: R.Shape((5, b * 2)) = R.shape((5, b * 2))
+ gv: R.Tensor((5, b * 2), "float32") = R.reshape(x, lv)
+ return gv
+
+ # After lowering, redundant var might be removed by later dead code
elimination
+ @I.ir_module
+ class Expected3:
+ @T.prim_func
+ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
+ T.func_attr({"tir.noalias": True})
+ b = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b))
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b *
T.int64(2)))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(5), b * T.int64(2)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(
+ rxplaceholder[
+ (v_ax0 * (b * T.int64(2)) + v_ax1) // b %
T.int64(10),
+ (v_ax0 * (b * T.int64(2)) + v_ax1) % b,
+ ]
+ )
+ T.writes(T_reshape[v_ax0, v_ax1])
+ T_reshape[v_ax0, v_ax1] = rxplaceholder[
+ (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10),
+ (v_ax0 * (b * T.int64(2)) + v_ax1) % b,
+ ]
+
+ @R.function
+ def main(
+ x: R.Tensor((10, "b"), dtype="float32")
+ ) -> R.Tensor((5, "b * 2"), dtype="float32"):
+ b = T.int64()
+ lv: R.Shape([5, b * 2]) = R.shape([5, b * 2])
+ gv = R.call_tir(
+ Expected3.reshape, (x,), out_sinfo=R.Tensor((5, b * 2),
dtype="float32")
+ )
+ return gv
+
+ mod3 = LegalizeOps()(Reshape3)
+ tvm.ir.assert_structural_equal(mod3, Expected3)
+
+
+def test_data_dependent_reshape():
+ # fmt: off
+ @tvm.script.ir_module
+ class DDReshape:
+ @R.function
+ def main(x: R.Tensor((3, ), dtype="int64")):
+ lv: R.Shape([3,]) = R.tensor_to_shape(x)
+ gv = R.reshape(x, lv)
+ return gv
+ # fmt: on
+
+ assert relax.analysis.well_formed(DDReshape)
+ mod = relax.transform.DecomposeCompositeOps()(DDReshape)
+ out_mod = relax.transform.LegalizeOps()(mod)
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def reshape(
+ rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape:
T.handle
+ ):
+ T.func_attr({"tir.noalias": True})
+ x = T.int64()
+ T_reshape = T.match_buffer(var_T_reshape, (x,), "int64")
+ # with T.block("root"):
+ for ax0 in range(x):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(x, ax0)
+ T.reads(rxplaceholder[v_ax0 % T.int64(3)])
+ T.writes(T_reshape[v_ax0])
+ T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)]
+
+ @R.function
+ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,),
dtype="int64"):
+ x_1 = T.int64()
+ gv: R.Shape([3]) = R.call_packed(
+ "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)
+ )
+ y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1]))
+ lv: R.Shape([x_1]) = R.shape([x_1])
+ gv_1 = R.call_tir(Expected.reshape, (x,),
out_sinfo=R.Tensor((x_1,), dtype="int64"))
+ return gv_1
+ # fmt: on
+ tvm.ir.assert_structural_equal(out_mod, Expected)
+
def test_split_by_indices():
# fmt: off