This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0ddfc657db [Unity] Implement FNormalize for relax.op.call_tir (#16068)
0ddfc657db is described below
commit 0ddfc657db410024e7261aa93b6109465c89484a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Nov 14 15:03:09 2023 -0600
[Unity] Implement FNormalize for relax.op.call_tir (#16068)
* [Unity] Implement FNormalize for relax.op.call_tir
Prior to this commit, `relax.op.call_tir` could express the TIR
arguments as either an in-line tuple, or as a variable bound to a
tuple. Because several passes assume the arguments will always be an
in-line tuple, this is being codified as the normal form of
`relax.op.call_tir`. Any upstream transform that produces
`relax.op.call_tir` with arguments provided as a by-variable tuple
will either be normalized to an in-line tuple if possible, or will
produce an error during the upstream transform otherwise.
This commit is specifically to allow the current usage of
`Downcast<Tuple>(call->args[1])` in passes such as `CallTIRRewrite`,
`FoldConstant`, `FuseTIR`, and `RewriteDataflowReshape`.
* Resolve unit test failures
* Added normalization for call_tir_inplace and call_tir_with_grad
* Normalize arg_tuple to (arg_tuple[0], ..., arg_tuple[N]) if unknown
---
src/relax/analysis/well_formed.cc | 16 +-
src/relax/op/op.cc | 89 +++++++--
tests/python/relax/test_transform_cse.py | 47 +++++
...st_transform_operator_specific_normalization.py | 208 ++++++++++++++++++++-
4 files changed, 341 insertions(+), 19 deletions(-)
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 5cb577e82b..9eed78d270 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -319,9 +319,19 @@ class WellFormedChecker : public relax::ExprVisitor,
if (auto func_normalize = op_map_normalize_.get(call->op, nullptr);
func_normalize != nullptr) {
auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
- auto before_normalize = GetRef<Call>(call);
- auto after_normalize = func_normalize(dummy_builder, before_normalize);
- if (!before_normalize.same_as(after_normalize)) {
+ Call before_normalize = GetRef<Call>(call);
+ Optional<Expr> after_normalize = NullOpt;
+ try {
+ after_normalize = func_normalize(dummy_builder, before_normalize);
+ } catch (std::exception& err) {
+ Malformed(
+ Diagnostic::Error(call)
+ << "If an operator defines an operator-specific normalization
function (FNormalize), "
+ << "calls to that operator must be normalized with it. "
+ << "However, normalization of " << before_normalize << " resulted
in the error: \n"
+ << err.what());
+ }
+ if (after_normalize && !before_normalize.same_as(after_normalize)) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization
function (FNormalize), "
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index fe74286a51..f51d2cc74f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -253,11 +253,70 @@ StructInfo InferStructInfoCallTIR(const Call& call, const
BlockBuilder& ctx) {
return call->sinfo_args[0];
}
-Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
- // Temporary implementation to ensure that at least one op has a
- // registered value for FNormalize. This temporary implementation
- // is fully implemented in follow-up PR
- // https://github.com/apache/tvm/pull/16068.
+Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
+ // This function is used for normalization of `relax.call_tir`,
+ // along with the variants `relax.call_tir_with_grad` and
+ // `relax.call_tir_inplace`. Therefore, all error messages should
+ // be written in terms of `call->op`, and should not explicitly
+  // reference the `relax.call_tir` operator.
+ CHECK(call->args.size() == 2 || call->args.size() == 3)
+ << "Operation " << call->op << " expects either two arguments [callee,
arg_tuple], "
+ << "or three arguments [callee, arg_tuple, tir_args], "
+ << "but " << call << " has " << call->args.size() << " arguments.";
+
+ Expr arg_expr = call->args[1];
+
+ CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
+ << "Operation " << call->op << " expects the second argument to be a
tuple of relax Expr. "
+ << "However, the second argument " << arg_expr << " has struct info "
+ << arg_expr->struct_info_ << ".";
+
+ if (arg_expr.as<TupleNode>()) {
+ return std::move(call);
+ }
+
+ CHECK(arg_expr.as<VarNode>())
+ << "Operation " << call->op << " must hold its arguments as an in-line
tuple. "
+ << "However, " << call << " has arguments " << arg_expr
+ << ", which is neither an in-line tuple, "
+ << "nor a variable binding that may be normalized to an in-line tuple.";
+
+ auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
+ if (auto var = expr.as<Var>()) {
+ if (auto bound_value = ctx->LookupBinding(var.value())) {
+ return bound_value.value();
+ }
+ }
+ return NullOpt;
+ };
+
+ while (auto unwrapped = unwrap_binding(arg_expr)) {
+ arg_expr = unwrapped.value();
+ }
+
+ Tuple new_arg_expr = [&]() {
+ // Preferred replacement. The argument tuple is provided as a
+ // variable, but we know the value bound to that variable.
+ if (auto opt = arg_expr.as<Tuple>()) {
+ return opt.value();
+ }
+
+ // Fallback case. The argument tuple is provided as a variable,
+ // and we don't know the value bound to that variable. For
+  // example, if a relax function accepted a tuple as a parameter,
+ // then provided that same tuple as an argument to call_tir.
+ Array<Expr> tuple_elements;
+ size_t num_fields =
Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
+ for (size_t i = 0; i < num_fields; i++) {
+ tuple_elements.push_back(TupleGetItem(arg_expr, i));
+ }
+ return Tuple(tuple_elements);
+ }();
+
+ auto new_args = call->args;
+ new_args.Set(1, new_arg_expr);
+ call.CopyOnWrite()->args = new_args;
+
return std::move(call);
}
@@ -314,6 +373,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
"ShapeExpr representing a tuple of ints to unpack during
runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
+ .set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));
Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo>
out_sinfo_list,
@@ -353,14 +413,12 @@
TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWit
// call_tir_inplace
-StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder&
ctx) {
- if (call->sinfo_args.size() != 1) {
- ctx->ReportFatal(Diagnostic::Error(call)
- << "sinfo_args should have exactly 1 output struct
info.");
- }
- CHECK(call->args[0]->IsInstance<GlobalVarNode>())
- << "call_tir expects the first argument to be a GlobalVar referring to a
TIR PrimFunc. "
- << "However, gets " << call->args[0];
+Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
+ // Apply normalization before error checks. This allows the error
+ // checks to safely apply `Downcast<Tuple>(call->args[1])`, which
+ // may result in an error if performed before normalization.
+ call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));
+
// there must be an inplace index for each output
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
size_t num_outputs = 1U;
@@ -443,7 +501,7 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call,
const BlockBuilder& c
}
}
- return call->sinfo_args[0];
+ return std::move(call);
}
TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs);
@@ -456,7 +514,8 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
.add_argument("packed_ints", "Expr",
"ShapeExpr representing a tuple of ints to unpack during
runtime. Omitted from "
"args if unused")
- .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCallTIRInplace)
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
+ .set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
// Warning: considered pure, but it has the potential to create visible
effects!
// This should only be used if it has been *checked* that it is safe (no
aliases, in-place
// arguments will no longer be live)
diff --git a/tests/python/relax/test_transform_cse.py
b/tests/python/relax/test_transform_cse.py
index 92cf4349d4..3a57afb22c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -292,5 +292,52 @@ def test_do_not_eliminate_extern_func():
verify(Before, Expected)
+def test_call_tir_tuple_arg():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16],
"int32")):
+ cls = Before
+ Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16,
16], "int32"))
+ Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16],
"int32"))
+ return (Prod, Sum)
+
+ @T.prim_func(private=True)
+ def product(
+ A: T.Buffer([16, 16], "int32"),
+ B: T.Buffer([16, 16], "int32"),
+ C: T.Buffer([16, 16], "int32"),
+ ):
+ for iters in T.grid(*A.shape):
+ with T.block("compute"):
+ i, j = T.axis.remap("SS", iters)
+ C[i, j] = A[i, j] * B[i, j]
+
+ @T.prim_func(private=True)
+ def sum(
+ A: T.Buffer([16, 16], "int32"),
+ B: T.Buffer([16, 16], "int32"),
+ C: T.Buffer([16, 16], "int32"),
+ ):
+ for iters in T.grid(*A.shape):
+ with T.block("compute"):
+ i, j = T.axis.remap("SS", iters)
+ C[i, j] = A[i, j] + B[i, j]
+
+ Expected = Before
+
+ # If EliminateCommonSubexpr produces unnormalized expressions,
+ # normalization of those expressions may produce additional
+ # variables bindings. This test case should be agnostic to those
+ # additional bindings, so DCE is applied after CSE.
+ After = tvm.ir.transform.Sequential(
+ [
+ EliminateCommonSubexpr(),
+ tvm.relax.transform.DeadCodeElimination(),
+ ]
+ )(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_operator_specific_normalization.py
b/tests/python/relax/test_transform_operator_specific_normalization.py
index 07d541ab1e..4ee1716645 100644
--- a/tests/python/relax/test_transform_operator_specific_normalization.py
+++ b/tests/python/relax/test_transform_operator_specific_normalization.py
@@ -22,7 +22,7 @@ import tvm.testing
import tvm.relax.testing.transform
from tvm import relax
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
import pytest
@@ -167,5 +167,211 @@ def test_un_normalized_call_node_is_ill_formed(custom_op,
define_normalization):
assert relax.analysis.well_formed(Module)
[email protected]_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir(custom_op):
+ """FNormalize in-lines the argument tuple for R.call_tir"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Before
+ args = (A,)
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir"),
+ [cls.multiply_by_two, args],
+ sinfo_args=[A.struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Expected
+ args = (A,)
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir"),
+ [cls.multiply_by_two, relax.Tuple([A])],
+ sinfo_args=[A.struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+ assert not tvm.ir.structural_equal(Before, After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
[email protected]_well_formed_check_before_transform
+def test_normalize_argument_to_inline_tuple_for_call_tir(custom_op):
+ """FNormalize in-lines the argument tuple for R.call_tir
+
+ Like `test_normalize_to_inline_tuple_for_call_tir`, but the
+ argument tuple is provided as a relax function argument.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(args: R.Tuple([R.Tensor([16], "float32")])):
+ cls = Before
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir"),
+ [cls.multiply_by_two, args],
+ sinfo_args=[args[0].struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(args: R.Tuple([R.Tensor([16], "float32")])):
+ cls = Expected
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir"),
+ [cls.multiply_by_two, relax.Tuple([args[0]])],
+ sinfo_args=[args[0].struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+ assert not tvm.ir.structural_equal(Before, After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
[email protected]_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir_inplace(custom_op):
+ """FNormalize in-lines the argument tuple for R.call_tir_inplace"""
+
+ # The CallTIRInplaceAttrs cannot be constructed from the Python
+ # API. Therefore, declaring the Expected output first, so that
+ # the attributes can be used for the non-normalized Before.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Expected
+ args = (A,)
+ return R.call_tir_inplace(
+ cls.multiply_by_two,
+ A,
+ inplace_indices=[0],
+ out_sinfo=[A.struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32")):
+ for i in range(16):
+ A[i] = A[i] * 2.0
+
+ inplace_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Before
+ args = (A,)
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir_inplace"),
+ [cls.multiply_by_two, args],
+ attrs=inplace_attrs,
+ sinfo_args=[A.struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32")):
+ for i in range(16):
+ A[i] = A[i] * 2.0
+
+ After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+ assert not tvm.ir.structural_equal(Before, After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
[email protected]_well_formed_check_before_transform
+def test_normalize_to_inline_tuple_for_call_tir_with_grad(custom_op):
+ """FNormalize in-lines the argument tuple for R.call_tir_with_grad"""
+
+ # The CallTIRWithGradAttrs cannot be constructed from the Python
+ # API. Therefore, declaring the Expected output first, so that
+ # the attributes can be used for the non-normalized Before.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Expected
+ args = (A,)
+ return R.call_tir_with_grad(
+ cls.multiply_by_two,
+ A,
+ out_sinfo=[A.struct_info],
+ te_grad_name="f_grad",
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ @T.prim_func(private=True)
+ def f_grad(
+ A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad:
T.Buffer(16, "float32")
+ ):
+ for i in range(16):
+ Grad[i] = 2.0
+
+ with_grad_attrs = Expected["main"].body.blocks[0].bindings[1].value.attrs
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ cls = Before
+ args = (A,)
+ return relax.Call(
+ tvm.ir.Op.get("relax.call_tir_with_grad"),
+ [cls.multiply_by_two, args],
+ attrs=with_grad_attrs,
+ sinfo_args=[A.struct_info],
+ )
+
+ @T.prim_func(private=True)
+ def multiply_by_two(A: T.Buffer(16, "float32"), B: T.Buffer(16,
"float32")):
+ for i in range(16):
+ B[i] = A[i] * 2.0
+
+ @T.prim_func(private=True)
+ def f_grad(
+ A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32"), Grad:
T.Buffer(16, "float32")
+ ):
+ for i in range(16):
+ Grad[i] = 2.0
+
+ After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before)
+
+ assert not tvm.ir.structural_equal(Before, After)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()