This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 15ba19fa78 [Unity][BYOC] Assign group to unused bindings and ignore
PrimFunc (#14139)
15ba19fa78 is described below
commit 15ba19fa78148e6d9146fdba4539a0d9ba1dbf47
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Feb 27 12:20:39 2023 -0800
[Unity][BYOC] Assign group to unused bindings and ignore PrimFunc (#14139)
* [Unity][BYOC] Assign group to unused bindings and ignore PrimFunc
* Update fuse_ops.cc
---
src/relax/transform/fuse_ops.cc | 46 ++++----
src/relax/transform/run_codegen.cc | 3 +
.../relax/test_transform_fuse_ops_by_pattern.py | 121 ++++++++++++++++++++-
3 files changed, 144 insertions(+), 26 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 813c0c8f03..c5042d0191 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -890,14 +890,6 @@ IRModule MakeGroupedFunctions(
return OperatorFusor(mod, partition, lift_constants).Transform();
}
-static Map<Expr, Var> GetBindingInverse(const Map<Var, Expr>& binding) {
- Map<Expr, Var> value_to_bound_var;
- for (const auto& [var, val] : binding) {
- value_to_bound_var.Set(val, var);
- }
- return value_to_bound_var;
-}
-
/*! \brief Create a "partitioning", a map from interior / leaf expr to its
representative group,
* based on the provided pattern. The result can be passed to OperatorFusor
above to fuse operations
* in a group and create a grouped function.
@@ -909,21 +901,26 @@ class PatternBasedPartitioner : ExprVisitor {
using ExprVisitor::VisitExpr_;
static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr,
support::Arena* arena) {
- PatternBasedPartitioner part(pattern_name, pattern,
AnalyzeVar2Value(expr));
- // Initialize each expr to have its own group
- PostOrderVisit(
- expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] =
arena->make<Group>(); });
+ PatternBasedPartitioner part(pattern_name, pattern, arena);
part.VisitExpr(expr);
return part.group_map_;
}
- PatternBasedPartitioner(String pattern_name, DFPattern pattern, const
Map<Var, Expr>& bindings)
- : pat_name_(pattern_name),
- pat_(pattern),
- bindings_(bindings),
- value_to_bound_var_(GetBindingInverse(bindings)) {}
+ PatternBasedPartitioner(String pattern_name, DFPattern pattern,
support::Arena* arena)
+ : pat_name_(pattern_name), pat_(pattern), arena_(arena) {}
+
+ void VisitVarDef(const Var& var) final { group_map_[var.get()] =
arena_->make<Group>(); }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ bindings_.Set(binding->var, binding->value);
+ value_to_bound_var_.Set(binding->value, binding->var);
+ ExprVisitor::VisitBinding_(binding);
+ }
+
+ void VisitExpr_(const ConstantNode* op) final { group_map_[op] =
arena_->make<Group>(); }
- void VisitExpr_(const CallNode* call) override {
+ void VisitBinding_(const VarBindingNode* binding, const CallNode* call)
final {
+ VisitVarDef(binding->var);
if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call),
bindings_)) {
// If a match is found, put all matching expressions into the same group.
// OperatorFusor also requires that the bound variable be in the same
group as the RHS value.
@@ -939,15 +936,12 @@ class PatternBasedPartitioner : ExprVisitor {
// conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
// parent_group corresponds to the group of "conv1" above.
- auto parent_group = GetGroupForBoundVar(GetRef<Call>(call));
+ auto parent_group = GetGroupForBoundVar(binding->var);
ICHECK(parent_group);
parent_group->attrs.Set(attr::kComposite, pat_name_);
-
for (const auto& [pat, match] : matches_opt.value()) {
- ICHECK(group_map_.count(match.get()));
// Put all matching call nodes into the parent group.
if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call))
{
- AddToGroup(match, parent_group);
// Put the bound variable on the LHS into the same parent group.
AddToGroup(value_to_bound_var_[match], parent_group);
}
@@ -964,15 +958,14 @@ class PatternBasedPartitioner : ExprVisitor {
}
}
- Group* GetGroupForBoundVar(Expr e) {
- ICHECK(value_to_bound_var_.count(e));
- auto bound_var = value_to_bound_var_[e];
+ Group* GetGroupForBoundVar(const Var& bound_var) {
ICHECK(group_map_.count(bound_var.get()));
return group_map_[bound_var.get()]->FindRoot();
}
String pat_name_;
DFPattern pat_;
+ support::Arena* arena_;
Map<Var, Expr> bindings_;
Map<Expr, Var> value_to_bound_var_;
GroupMap group_map_;
@@ -1055,6 +1048,9 @@ IRModule FuseOpsByPattern(const tvm::Array<String>&
pattern_names,
for (size_t i = 0; i < pattern_names.size(); ++i) {
OperatorFusor::GroupMap group_map;
for (const auto& entry : mod->functions) {
+ if (entry.second->IsInstance<tir::PrimFuncNode>()) {
+ continue;
+ }
auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i],
entry.second, &arena);
group_map.insert(map.begin(), map.end());
}
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
index 114b7d2a34..7deeb139d1 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -138,6 +138,9 @@ class CodeGenRunner : ExprMutator {
std::unordered_map<std::string, Array<Function>> target_functions;
for (const auto& entry : mod->functions) {
+ if (entry.second->IsInstance<tir::PrimFuncNode>()) {
+ continue;
+ }
PostOrderVisit(entry.second, [&target_functions](Expr e) {
if (e->IsInstance<FunctionNode>()) {
auto f = Downcast<Function>(e);
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index da5b92fb64..21f952096b 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -20,7 +20,7 @@ import numpy as np
import tvm
from tvm import relax
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T, ir as I
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op,
wildcard
@@ -460,5 +460,124 @@ def test_multiple_calls_same_extern():
check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned,
annoatate_codegen=True)
+def test_ignore_call_tir():
+ @I.ir_module
+ class Conv2dReLUCallTIR:
+ @T.prim_func
+ def relu(
+ data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64,
64, 56, 56), "float32")
+ ):
+ for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+ with T.block("root"):
+ i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ out[i, j, k, l] = T.max(data[i, j, k, l], 0.0)
+
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+ relu1 = R.call_tir(relu, (conv1,), R.Tensor((64, 64, 56, 56),
"float32"))
+ R.output(relu1)
+
+ return relu1
+
+ @I.ir_module
+ class Conv2dReLUCallTIR_partitioned:
+ @T.prim_func
+ def relu(
+ data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64,
64, 56, 56), "float32")
+ ):
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+ with T.block("root"):
+ i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(data[i, j, k, l])
+ T.writes(out[i, j, k, l])
+ out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))
+
+ @R.function
+ def fused_relax_nn_conv2d(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data,
+ weight1,
+ padding=(1, 1),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(
+ data, weight1
+ )
+ relu1 = R.call_tir(
+ relu, (lv,), out_sinfo=R.Tensor((64, 64, 56, 56),
dtype="float32")
+ )
+ R.output(relu1)
+ return relu1
+
+ pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
+ check(Conv2dReLUCallTIR, [("cutlass.conv2d", pat)],
Conv2dReLUCallTIR_partitioned)
+
+
+def test_unused():
+ @I.ir_module
+ class Conv2dReLU:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+ relu = R.nn.relu(data)
+ R.output(conv1)
+
+ return conv1
+
+ @I.ir_module
+ class Conv2dReLU_partitioned:
+ @R.function
+ def fused_relax_nn_conv2d(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data, weight1, padding=(1, 1)
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(
+ data, weight1
+ )
+ relu: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.relu(data)
+ R.output(gv)
+ return gv
+
+ pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
+ check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned)
+
+
if __name__ == "__main__":
pytest.main([__file__])