This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ec9e0a0655 [Unity] Allow FLegalize to produce Relax operations (#15842)
ec9e0a0655 is described below
commit ec9e0a06557fdce52dc665dda19b48ea8408804f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 19 14:29:22 2023 -0500
[Unity] Allow FLegalize to produce Relax operations (#15842)
* [Unity] Allow FLegalize to produce Relax operations
Prior to this commit, a `FLegalize` function needed to produce an
implementation that could be used as input by
`relax.transform.AnnotateTIROpPattern`, and could not lower to other
relax operations. This commit allows Relax operations to be included
in the output of `FLegalize`, with the result being further legalized
if required.
* Maintain binding block type for nested SeqExpr
* Avoid infinite recursion for strided slice on dynamic axis
* Avoid duplicate variables when checking for re-legalization
* Collect bindings to legalize during normalization
---
src/relax/ir/block_builder.cc | 12 +++-
src/relax/transform/legalize_ops.cc | 80 ++++++++++++++++-------
tests/python/relax/test_transform_legalize_ops.py | 74 +++++++++++++++++++++
3 files changed, 141 insertions(+), 25 deletions(-)
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index cc79d45323..5037161fcb 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -612,7 +612,17 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
unchanged &= new_block.same_as(block);
}
- this->BeginBindingBlock();
+ // Because the input may not be normalized, the SeqExpr may occur
+ // nested within another SeqExpr. In that case, we want to use
+ // whatever binding-block type the parent uses, so that any
+ // bindings collected into the prologue will be compatible with
+ // the parent block.
+ if (block_stack_.size() && CurrentBlockIsDataFlow()) {
+ this->BeginDataflowBlock();
+ } else {
+ this->BeginBindingBlock();
+ }
+
// the body may not be a leaf expression, so check for that
Expr new_body = this->NormalizeArgument(op->body);
unchanged &= new_body.same_as(op->body);
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 4469f35585..170967d282 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -57,10 +57,11 @@ class LegalizeMutator : public ExprMutator {
public:
explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String,
PackedFunc>>& cmap,
bool enable_warning)
- : ExprMutator(mod),
- mod_(std::move(mod)),
- cmap_(std::move(cmap)),
- enable_warning_(enable_warning) {}
+ : ExprMutator(mod), mod_(std::move(mod)),
enable_warning_(enable_warning) {
+ if (cmap) {
+ cmap_ = std::move(cmap.value());
+ }
+ }
IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
@@ -132,36 +133,67 @@ class LegalizeMutator : public ExprMutator {
return visited_call;
}
- // Priority: customize > default.
- // Check if it has customize legalization registered.
- if (cmap_.defined() && cmap_.value().count(op->name)) {
- auto ret = cmap_.value()[op->name](this->builder_, visited_call);
- if (ret.IsObjectRef<Expr>() && WrapPureCondition(op,
ret.AsObjectRef<Expr>())) {
- return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
+ FLegalize legalization_func;
+
+ if (auto opt_custom_legalize = cmap_.Get(op->name)) {
+ // First choice, use a custom legalization function
+ legalization_func = opt_custom_legalize.value();
+ } else if (legalize_map.count(op)) {
+ // Second choice, use a default legalization
+ legalization_func = legalize_map[op];
+ } else {
+ // No legalization.
+ if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
+ op != call_pure_packed_op) {
+ LOG(WARNING) << "No legalization func for " << op->name << " is
found.";
}
- return ret;
+ return visited_call;
}
- // Check if it has default legalization registered.
- if (legalize_map.count(op)) {
- auto ret = legalize_map[op](this->builder_, visited_call);
- if (WrapPureCondition(op, ret)) {
- return WrapPureCall(Downcast<Call>(ret));
- }
- return ret;
+
+ // The legalization function may call `builder_->Emit()` as part
+ // of its implementation. In that case, any operations it emits
+ // must be caught so that they can be checked for recursive
+ // legalization. This is done by collecting them into a new binding
+ // block opened before the legalization function is called, then
+ // visiting each binding in that block after it is ended.
+ if (builder_->CurrentBlockIsDataFlow()) {
+ builder_->BeginDataflowBlock();
+ } else {
+ builder_->BeginBindingBlock();
}
+ Expr legalized = legalization_func(builder_, visited_call);
+ legalized = builder_->Normalize(legalized);
- // No legalization.
- if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
- op != call_pure_packed_op) {
- LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+ BindingBlock prologue = builder_->EndBlock();
+ for (const auto& binding : prologue->bindings) {
+ VisitBinding(binding);
}
- return visited_call;
+
+ if (WrapPureCondition(op, legalized)) {
+ legalized = WrapPureCall(Downcast<Call>(legalized));
+ }
+
+ // Legalization may have introduced additional operations that
+ // must be legalized as well. For example, a user-defined
+ // intrinsic whose legalization is implemented in terms of relax
+ // intrinsics. The base case of the recursion occurs when no
+ // additional legalization steps are found.
+ //
+ // Only perform recursive legalization when the legalization
+ // function returned a modified expression, as some legalizations
+ // return the original expression if they are unable to produce a
+ // legalized version.
+ if (!legalized.same_as(visited_call)) {
+ legalized = VisitExpr(legalized);
+ }
+
+ return legalized;
}
/*! \brief The context IRModule. */
IRModule mod_;
/*! \brief The customized legalization function map. */
- Optional<Map<String, PackedFunc>> cmap_;
+ Map<String, PackedFunc> cmap_;
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py
index af6004bd0a..47eeb68341 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -24,6 +24,8 @@ from tvm.relax.transform.legalize_ops.common import register_legalize
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
+import pytest
+
def test_customize_legalize():
# fmt: off
@@ -282,5 +284,77 @@ def test_matmul_legalization_requires_known_dtype():
assert err_message.startswith("To legalize R.matmul")
+emit_legalization_through_builder = tvm.testing.parameter(
+ by_dict={
+ "return_relax_expr": False,
+ "return_relax_var": True,
+ }
+)
+
+
+@pytest.fixture
+def custom_op(emit_legalization_through_builder):
+ op_name = "custom_op.matmul_bias_add"
+
+ def infer_struct_info(call: relax.Call, context):
+ activations, weight, bias = call.args
+
+ matmul_call = relax.op.matmul(activations, weight)
+ matmul_sinfo =
tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+ matmul_call, context
+ )
+
+ matmul_var = relax.Var("dummy_var", matmul_sinfo)
+ add_call = matmul_var + bias
+ add_sinfo =
tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+ return add_sinfo
+
+ def legalize(bb: relax.BlockBuilder, call: relax.Call):
+ activations, weight, bias = call.args
+ legalized = relax.op.matmul(activations, weight) + bias
+ if emit_legalization_through_builder:
+ legalized = bb.emit(legalized)
+ return legalized
+
+ op_attrs = {
+ "FInferStructInfo": infer_struct_info,
+ "FLegalize": legalize,
+ "FPurity": True,
+ }
+
+ for key, value in op_attrs.items():
+ tvm.ir.register_op_attr(op_name, key, value)
+
+ op = tvm.ir.Op.get(op_name)
+ yield op
+
+ for key in op_attrs:
+ op.reset_attr(key)
+
+
+def test_recursive_legalization(custom_op):
+ """Legalization of an operator may produce new operators requiring legalization"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ A: R.Tensor([16, 32, 64], "float32"),
+ Weight: R.Tensor([64, 128], "float32"),
+ Bias: R.Tensor([16, 32, 128], "float32"),
+ ):
+ return relax.Call(custom_op, [A, Weight, Bias])
+
+ AfterFirstIter = LegalizeOps()(Before)
+ AfterSecondIter = LegalizeOps()(AfterFirstIter)
+
+ # After LegalizeOps, the custom operation should be replaced by
+ # `R.matmul` and `R.add`, which should in turn be replaced with
+ # TIR implementations. Therefore, the second application of
+ # LegalizeOps() should be a no-op.
+ tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
+
+
if __name__ == "__main__":
tvm.testing.main()