This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0ddfc657db [Unity] Implement FNormalize for relax.op.call_tir  (#16068)
0ddfc657db is described below

commit 0ddfc657db410024e7261aa93b6109465c89484a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Nov 14 15:03:09 2023 -0600

    [Unity] Implement FNormalize for relax.op.call_tir  (#16068)
    
    * [Unity] Implement FNormalize for relax.op.call_tir
    
    Prior to this commit, `relax.op.call_tir` could express the TIR
    arguments as either an in-line tuple, or as a variable bound to a
    tuple.  Because several passes assume the arguments will always be an
    in-line tuple, this is being codified as the normal form of
    `relax.op.call_tir`.  Any upstream transform that produces
    `relax.op.call_tir` with arguments provided as a by-variable tuple
    will either be normalized to an in-line tuple if possible, or will
    produce an error during the upstream transform otherwise.
    
    This commit is specifically to allow the current usage of
    `Downcast<Tuple>(call->args[1])` in passes such as `CallTIRRewrite`,
    `FoldConstant`, `FuseTIR`, and `RewriteDataflowReshape`.
    
    * Resolve unit test failures
    
    * Added normalization for call_tir_inplace and call_tir_with_grad
    
    * Normalize arg_tuple to (arg_tuple[0], ..., arg_tuple[N-1]) if its bound value is unknown
---
 src/relax/analysis/well_formed.cc                  |  16 +-
 src/relax/op/op.cc                                 |  89 +++++++--
 tests/python/relax/test_transform_cse.py           |  47 +++++
 ...st_transform_operator_specific_normalization.py | 208 ++++++++++++++++++++-
 4 files changed, 341 insertions(+), 19 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc 
b/src/relax/analysis/well_formed.cc
index 5cb577e82b..9eed78d270 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -319,9 +319,19 @@ class WellFormedChecker : public relax::ExprVisitor,
 
     if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); 
func_normalize != nullptr) {
       auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
-      auto before_normalize = GetRef<Call>(call);
-      auto after_normalize = func_normalize(dummy_builder, before_normalize);
-      if (!before_normalize.same_as(after_normalize)) {
+      Call before_normalize = GetRef<Call>(call);
+      Optional<Expr> after_normalize = NullOpt;
+      try {
+        after_normalize = func_normalize(dummy_builder, before_normalize);
+      } catch (std::exception& err) {
+        Malformed(
+            Diagnostic::Error(call)
+            << "If an operator defines an operator-specific normalization 
function (FNormalize), "
+            << "calls to that operator must be normalized with it.  "
+            << "However, normalization of " << before_normalize << " resulted 
in the error: \n"
+            << err.what());
+      }
+      if (after_normalize && !before_normalize.same_as(after_normalize)) {
         Malformed(
             Diagnostic::Error(call)
             << "If an operator defines an operator-specific normalization 
function (FNormalize), "
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index fe74286a51..f51d2cc74f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -253,11 +253,70 @@ StructInfo InferStructInfoCallTIR(const Call& call, const 
BlockBuilder& ctx) {
   return call->sinfo_args[0];
 }
 
-Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
-  // Temporary implementation to ensure that at least one op has a
-  // registered value for FNormalize.  This temporary implementation
-  // is fully implemented in follow-up PR
-  // https://github.com/apache/tvm/pull/16068.
+Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
+  // This function is used for normalization of `relax.call_tir`,
+  // along with the variants `relax.call_tir_with_grad` and
+  // `relax.call_tir_inplace`.  Therefore, all error messages should
+  // be written in terms of `call->op`, and should not explicitly
+  // reference the `relax.call_tir` operator.
+  CHECK(call->args.size() == 2 || call->args.size() == 3)
+      << "Operation " << call->op << " expects either two arguments [callee, 
arg_tuple], "
+      << "or three arguments [callee, arg_tuple, tir_args], "
+      << "but " << call << " has " << call->args.size() << " arguments.";
+
+  Expr arg_expr = call->args[1];
+
+  CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
+      << "Operation " << call->op << " expects the second argument to be a 
tuple of relax Expr.  "
+      << "However, the second argument " << arg_expr << " has struct info "
+      << arg_expr->struct_info_ << ".";
+
+  if (arg_expr.as<TupleNode>()) {
+    return std::move(call);
+  }
+
+  CHECK(arg_expr.as<VarNode>())
+      << "Operation " << call->op << " must hold its arguments as an in-line 
tuple.  "
+      << "However, " << call << " has arguments " << arg_expr
+      << ", which is neither an in-line tuple, "
+      << "nor a variable binding that may be normalized to an in-line tuple.";
+
+  auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
+    if (auto var = expr.as<Var>()) {
+      if (auto bound_value = ctx->LookupBinding(var.value())) {
+        return bound_value.value();
+      }
+    }
+    return NullOpt;
+  };
+
+  Expr bound_expr = arg_expr;
+  while (auto unwrapped = unwrap_binding(bound_expr)) {
+    bound_expr = unwrapped.value();
+  }
+
+  Tuple new_arg_expr = [&]() {
+    // Preferred replacement.  The argument tuple is provided as a
+    // variable, but we know the value bound to that variable.
+    if (auto opt = bound_expr.as<Tuple>()) {
+      return opt.value();
+    }
+
+    // Fallback case.  The bound value is unknown (e.g. a function
+    // parameter) or is not an in-line tuple (e.g. a call).  Index into
+    // the variable itself, so the bound expression is never duplicated.
+    Array<Expr> tuple_elements;
+    size_t num_fields = 
Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
+    for (size_t i = 0; i < num_fields; i++) {
+      tuple_elements.push_back(TupleGetItem(arg_expr, i));
+    }
+    return Tuple(tuple_elements);
+  }();
+
+  auto new_args = call->args;
+  new_args.Set(1, new_arg_expr);
+  call.CopyOnWrite()->args = new_args;
+
   return std::move(call);
 }
 
@@ -314,6 +373,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
                   "ShapeExpr representing a tuple of ints to unpack during 
runtime. Omitted from "
                   "args if unused")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
+    .set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
     .set_attr<Bool>("FPurity", Bool(true));
 
 Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> 
out_sinfo_list,
@@ -353,14 +413,12 @@ 
TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWit
 
 // call_tir_inplace
 
-StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& 
ctx) {
-  if (call->sinfo_args.size() != 1) {
-    ctx->ReportFatal(Diagnostic::Error(call)
-                     << "sinfo_args should have exactly 1 output struct 
info.");
-  }
-  CHECK(call->args[0]->IsInstance<GlobalVarNode>())
-      << "call_tir expects the first argument to be a GlobalVar referring to a 
TIR PrimFunc. "
-      << "However, gets " << call->args[0];
+Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
+  // Apply normalization before error checks.  This allows the error
+  // checks to safely apply `Downcast<Tuple>(call->args[1])`, which
+  // may result in an error if performed before normalization.
+  call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));
+
   // there must be an inplace index for each output
   const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
   size_t num_outputs = 1U;
@@ -443,7 +501,7 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, 
const BlockBuilder& c
     }
   }
 
-  return call->sinfo_args[0];
+  return std::move(call);
 }
 
 TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs);
@@ -456,7 +514,8 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
     .add_argument("packed_ints", "Expr",
                   "ShapeExpr representing a tuple of ints to unpack during 
runtime. Omitted from "
                   "args if unused")
-    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCallTIRInplace)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
+    .set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
     // Warning: considered pure, but it has the potential to create visible 
effects!
     // This should only be used if it has been *checked* that it is safe (no 
aliases, in-place
     // arguments will no longer be live)
diff --git a/tests/python/relax/test_transform_cse.py 
b/tests/python/relax/test_transform_cse.py
index 92cf4349d4..3a57afb22c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -292,5 +292,52 @@ def test_do_not_eliminate_extern_func():
     verify(Before, Expected)
 
 
+def test_call_tir_tuple_arg():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16], 
"int32")):
+            cls = Before
+            Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16, 
16], "int32"))
+            Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16], 
"int32"))
+            return (Prod, Sum)
+
+        @T.prim_func(private=True)
+        def product(
+            A: T.Buffer([16, 16], "int32"),
+            B: T.Buffer([16, 16], "int32"),
+            C: T.Buffer([16, 16], "int32"),
+        ):
+            for iters in T.grid(*A.shape):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] * B[i, j]
+
+        @T.prim_func(private=True)
+        def sum(
+            A: T.Buffer([16, 16], "int32"),
+            B: T.Buffer([16, 16], "int32"),
+            C: T.Buffer([16, 16], "int32"),
+        ):
+            for iters in T.grid(*A.shape):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+    Expected = Before  # the shared [A, B] argument tuple must stay in-line
+
+    # If EliminateCommonSubexpr produces unnormalized expressions,
+    # normalization of those expressions may produce additional
+    # variable bindings.  This test case should be agnostic to those
+    # additional bindings, so DCE is applied after CSE.
+    After = tvm.ir.transform.Sequential(
+        [
+            EliminateCommonSubexpr(),
+            tvm.relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/relax/test_transform_operator_specific_normalization.py 
b/tests/python/relax/test_transform_operator_specific_normalization.py
index 07d541ab1e..4ee1716645 100644
--- a/tests/python/relax/test_transform_operator_specific_normalization.py
+++ b/tests/python/relax/test_transform_operator_specific_normalization.py
@@ -22,7 +22,7 @@ import tvm.testing
 import tvm.relax.testing.transform
 
 from tvm import relax
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 import pytest
 
@@ -167,5 +167,211 @@ def test_un_normalized_call_node_is_ill_formed(custom_op, 
define_normalization):
         assert relax.analysis.well_formed(Module)
 
 
+@pytest.mark.skip_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir(custom_op):
+    """FNormalize in-lines a variable-bound argument tuple for R.call_tir"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Before
+            args = (A,)
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir"),
+                [cls.multiply_by_two, args],
+                sinfo_args=[A.struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Expected
+            args = (A,)
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir"),
+                [cls.multiply_by_two, relax.Tuple([A])],
+                sinfo_args=[A.struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+    After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+    assert not tvm.ir.structural_equal(Before, After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+@pytest.mark.skip_well_formed_check_before_transform
+def test_normalize_argument_to_inline_tuple_for_call_tir(custom_op):
+    """FNormalize in-lines the argument tuple for R.call_tir
+
+    Like `test_normalize_to_inline_tuple_for_call_tir`, but the tuple is a
+    function parameter with no known binding, so it becomes `(args[0],)`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(args: R.Tuple([R.Tensor([16], "float32")])):
+            cls = Before
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir"),
+                [cls.multiply_by_two, args],
+                sinfo_args=[args[0].struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(args: R.Tuple([R.Tensor([16], "float32")])):
+            cls = Expected
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir"),
+                [cls.multiply_by_two, relax.Tuple([args[0]])],
+                sinfo_args=[args[0].struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+    After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+    assert not tvm.ir.structural_equal(Before, After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+@pytest.mark.skip_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir_inplace(custom_op):
+    """FNormalize in-lines the argument tuple for R.call_tir_inplace"""
+
+    # The CallTIRInplaceAttrs cannot be constructed from the Python
+    # API.  Therefore, the Expected output is declared first, so that
+    # its attributes can be reused for the non-normalized Before.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Expected
+            args = (A,)
+            return R.call_tir_inplace(
+                cls.multiply_by_two,
+                A,
+                inplace_indices=[0],
+                out_sinfo=[A.struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32")):
+            for i in range(16):
+                A[i] = A[i] * 2.0
+
+    inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Before
+            args = (A,)
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir_inplace"),
+                [cls.multiply_by_two, args],
+                attrs=inplace_attrs,
+                sinfo_args=[A.struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32")):
+            for i in range(16):
+                A[i] = A[i] * 2.0
+
+    After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+    assert not tvm.ir.structural_equal(Before, After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+@pytest.mark.skip_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir_with_grad(custom_op):
+    """FNormalize in-lines the argument tuple for R.call_tir_with_grad"""
+
+    # The CallTIRWithGradAttrs cannot be constructed from the Python
+    # API.  Therefore, the Expected output is declared first, so that
+    # its attributes can be reused for the non-normalized Before.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Expected
+            args = (A,)
+            return R.call_tir_with_grad(
+                cls.multiply_by_two,
+                A,
+                out_sinfo=[A.struct_info],
+                te_grad_name="f_grad",
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+        @T.prim_func(private=True)
+        def f_grad(
+            A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: 
T.Buffer(16, "float32")
+        ):
+            for i in range(16):
+                Grad[i] = 2.0
+
+    with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            cls = Before
+            args = (A,)
+            return relax.Call(
+                tvm.ir.Op.get("relax.call_tir_with_grad"),
+                [cls.multiply_by_two, args],
+                attrs=with_grad_attrs,
+                sinfo_args=[A.struct_info],
+            )
+
+        @T.prim_func(private=True)
+        def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
+            for i in range(16):
+                B[i] = A[i] * 2.0
+
+        @T.prim_func(private=True)
+        def f_grad(
+            A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad: 
T.Buffer(16, "float32")
+        ):
+            for i in range(16):
+                Grad[i] = 2.0
+
+    After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+    assert not tvm.ir.structural_equal(Before, After)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to