This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ec9e0a0655 [Unity] Allow FLegalize to produce Relax operations (#15842)
ec9e0a0655 is described below

commit ec9e0a06557fdce52dc665dda19b48ea8408804f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 19 14:29:22 2023 -0500

    [Unity] Allow FLegalize to produce Relax operations (#15842)
    
    * [Unity] Allow FLegalize to produce Relax operations
    
    Prior to this commit, a `FLegalize` function needed to produce an
    implementation that can be used as input by
    `relax.transform.AnnotateTIROpPattern`, and could not lower to other
    relax operations.  This commit allows Relax operations to be included
    in the output of `FLegalize`, with the result being further legalized
    if required.
    
    * Maintain binding block type for nested SeqExpr
    
    * Avoid infinite recursion for strided slice on dynamic axis
    
    * Avoid duplicate variables when checking for re-legalization
    
    * Collect bindings to legalize during normalization
---
 src/relax/ir/block_builder.cc                     | 12 +++-
 src/relax/transform/legalize_ops.cc               | 80 ++++++++++++++++-------
 tests/python/relax/test_transform_legalize_ops.py | 74 +++++++++++++++++++++
 3 files changed, 141 insertions(+), 25 deletions(-)

diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index cc79d45323..5037161fcb 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -612,7 +612,17 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
       unchanged &= new_block.same_as(block);
     }
 
-    this->BeginBindingBlock();
+    // Because the input may not be normalized, the SeqExpr may occur
+    // nested within another SeqExpr.  In that case, we want to use
+    // whatever binding-block type the parent uses, so that any
+    // bindings collected into the prologue will be compatible with
+    // the parent block.
+    if (block_stack_.size() && CurrentBlockIsDataFlow()) {
+      this->BeginDataflowBlock();
+    } else {
+      this->BeginBindingBlock();
+    }
+
     // the body may not be a leaf expression, so check for that
     Expr new_body = this->NormalizeArgument(op->body);
     unchanged &= new_body.same_as(op->body);
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 4469f35585..170967d282 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -57,10 +57,11 @@ class LegalizeMutator : public ExprMutator {
  public:
   explicit LegalizeMutator(const IRModule& mod, const Optional<Map<String, PackedFunc>>& cmap,
                            bool enable_warning)
-      : ExprMutator(mod),
-        mod_(std::move(mod)),
-        cmap_(std::move(cmap)),
-        enable_warning_(enable_warning) {}
+      : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) {
+    if (cmap) {
+      cmap_ = std::move(cmap.value());
+    }
+  }
 
   IRModule Transform() {
     for (const auto& [gv, func] : mod_->functions) {
@@ -132,36 +133,67 @@ class LegalizeMutator : public ExprMutator {
       return visited_call;
     }
 
-    // Priority: customize > default.
-    // Check if it has customize legalization registered.
-    if (cmap_.defined() && cmap_.value().count(op->name)) {
-      auto ret = cmap_.value()[op->name](this->builder_, visited_call);
-      if (ret.IsObjectRef<Expr>() && WrapPureCondition(op, ret.AsObjectRef<Expr>())) {
-        return WrapPureCall(Downcast<Call>(ret.AsObjectRef<Expr>()));
+    FLegalize legalization_func;
+
+    if (auto opt_custom_legalize = cmap_.Get(op->name)) {
+      // First choice, use a custom legalization function
+      legalization_func = opt_custom_legalize.value();
+    } else if (legalize_map.count(op)) {
+      // Second choice, use a default legalization
+      legalization_func = legalize_map[op];
+    } else {
+      // No legalization.
+      if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
+          op != call_pure_packed_op) {
+        LOG(WARNING) << "No legalization func for " << op->name << " is found.";
       }
-      return ret;
+      return visited_call;
     }
-    // Check if it has default legalization registered.
-    if (legalize_map.count(op)) {
-      auto ret = legalize_map[op](this->builder_, visited_call);
-      if (WrapPureCondition(op, ret)) {
-        return WrapPureCall(Downcast<Call>(ret));
-      }
-      return ret;
+
+    // The legalization function may call `builder_->Emit()` as part
+    // of its implementation.  In that case, any operations it emits
+    // must be captured so that they can be checked for recursive
+    // legalization.  This is done by collecting them into a separate
+    // binding block (matching the dataflow-ness of the current block),
+    // and then visiting each of that block's bindings with this mutator.
+    if (builder_->CurrentBlockIsDataFlow()) {
+      builder_->BeginDataflowBlock();
+    } else {
+      builder_->BeginBindingBlock();
     }
+    Expr legalized = legalization_func(builder_, visited_call);
+    legalized = builder_->Normalize(legalized);
 
-    // No legalization.
-    if (enable_warning_ && op != call_tir_op && op != call_dps_packed_op &&
-        op != call_pure_packed_op) {
-      LOG(WARNING) << "No legalization func for " << op->name << " is found.";
+    BindingBlock prologue = builder_->EndBlock();
+    for (const auto& binding : prologue->bindings) {
+      VisitBinding(binding);
     }
-    return visited_call;
+
+    if (WrapPureCondition(op, legalized)) {
+      legalized = WrapPureCall(Downcast<Call>(legalized));
+    }
+
+    // Legalization may have introduced additional operations that
+    // must be legalized as well.  For example, a user-custom
+    // intrinsic whose legalization is implemented in terms of relax
+    // intrinsics.  The base case of the recursion occurs when no
+    // additional legalization steps are found.
+    //
+    // Only perform recursive legalization when the legalization
+    // function returned a modified expression, as some legalizations
+    // return the original expression if they are unable to produce a
+    // legalized version.
+    if (!legalized.same_as(visited_call)) {
+      legalized = VisitExpr(legalized);
+    }
+
+    return legalized;
   }
 
   /*! \brief The context IRModule. */
   IRModule mod_;
   /*! \brief The customized legalization function map. */
-  Optional<Map<String, PackedFunc>> cmap_;
+  Map<String, PackedFunc> cmap_;
   /*!
    * \brief A boolean value indicating if to print warnings for CallNode whose op's
    * legalization function is not registered.
diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py
index af6004bd0a..47eeb68341 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -24,6 +24,8 @@ from tvm.relax.transform.legalize_ops.common import register_legalize
 from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
 
+import pytest
+
 
 def test_customize_legalize():
     # fmt: off
@@ -282,5 +284,77 @@ def test_matmul_legalization_requires_known_dtype():
     assert err_message.startswith("To legalize R.matmul")
 
 
+emit_legalization_through_builder = tvm.testing.parameter(
+    by_dict={
+        "return_relax_expr": False,
+        "return_relax_var": True,
+    }
+)
+
+
+@pytest.fixture
+def custom_op(emit_legalization_through_builder):
+    op_name = "custom_op.matmul_bias_add"
+
+    def infer_struct_info(call: relax.Call, context):
+        activations, weight, bias = call.args
+
+        matmul_call = relax.op.matmul(activations, weight)
+        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+            matmul_call, context
+        )
+
+        matmul_var = relax.Var("dummy_var", matmul_sinfo)
+        add_call = matmul_var + bias
+        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+        return add_sinfo
+
+    def legalize(bb: relax.BlockBuilder, call: relax.Call):
+        activations, weight, bias = call.args
+        legalized = relax.op.matmul(activations, weight) + bias
+        if emit_legalization_through_builder:
+            legalized = bb.emit(legalized)
+        return legalized
+
+    op_attrs = {
+        "FInferStructInfo": infer_struct_info,
+        "FLegalize": legalize,
+        "FPurity": True,
+    }
+
+    for key, value in op_attrs.items():
+        tvm.ir.register_op_attr(op_name, key, value)
+
+    op = tvm.ir.Op.get(op_name)
+    yield op
+
+    for key in op_attrs:
+        op.reset_attr(key)
+
+
+def test_recursive_legalization(custom_op):
+    """Legalization of an operator may produce new operators requiring legalization"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([16, 32, 64], "float32"),
+            Weight: R.Tensor([64, 128], "float32"),
+            Bias: R.Tensor([16, 32, 128], "float32"),
+        ):
+            return relax.Call(custom_op, [A, Weight, Bias])
+
+    AfterFirstIter = LegalizeOps()(Before)
+    AfterSecondIter = LegalizeOps()(AfterFirstIter)
+
+    # After LegalizeOps, the custom operation should be replaced by
+    # `R.matmul` and `R.add`, which should in turn be replaced with
+    # TIR implementations.  Therefore, the second application of
+    # LegalizeOps() should be a no-op.
+    tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to