This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 61fbf4262d [Unity] Fix FuseTIR when the same buffer is read multiple
times with different access pattern (#14603)
61fbf4262d is described below
commit 61fbf4262d1729a97174fb118f05852fd8afc291
Author: masahi <[email protected]>
AuthorDate: Wed Apr 12 23:18:01 2023 +0900
[Unity] Fix FuseTIR when the same buffer is read multiple times with
different access pattern (#14603)
When the same buffer is read multiple times with different access patterns
in a single expression, the check below fails.
But this case should be allowed, for example in the following subgraph
inp_0 is used twice in different read regions. See the test case for details.
This subgraph arises if we run ConvertLayout on SD UNet from
web-stable-diffusion.
---
src/relax/transform/fuse_tir.cc | 43 +++------
tests/python/relax/test_transform_fuse_tir.py | 120 ++++++++++++++++++++++++++
2 files changed, 133 insertions(+), 30 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 183d395b60..432ddca0a7 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -117,15 +117,13 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr&
n, const PrimExpr& othe
/*!
* \brief Substitute a given source buffer with a given target buffer in
statements or expressions.
*/
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class FuseTIRBufferSubstitutor : private StmtExprMutator {
public:
- explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
- const Map<Var, Var>& var_map) {
+ explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
+ const Map<Var, Var>& var_map) {
buffer_remap_ = buffer_map;
var_remap_ = var_map;
- for (const auto& kv : buffer_map) {
- const Buffer& src = kv.first;
- const Buffer& tgt = kv.second;
+ for (const auto& [src, tgt] : buffer_map) {
var_remap_.Set(src->data, tgt->data);
}
}
@@ -246,8 +244,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
Map<tir::Buffer, tir::Buffer> buffer_remap_;
/*! \brief Mapping from src tir var to tgt var. */
Map<tir::Var, tir::Var> var_remap_;
- /*! \brief The structural equality checker */
- StructuralEqual structural_equal_;
Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>&
regions) const {
// For now we only allow Buffer access the same elements.
@@ -262,8 +258,6 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
if (it == buffer_region_set.end()) {
ret.push_back(region);
buffer_region_set[region->buffer.get()] = region->region;
- } else {
- ICHECK(structural_equal_(region->region, it->second));
}
}
@@ -351,10 +345,8 @@ class FusedTIRConstructor : public ExprVisitor {
// It's a symbolic shape var, no need to alloc Buffers.
continue;
}
- auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), //
- relax_param->name_hint());
- const Array<tir::Var>& params = ret.first;
- const Array<tir::Buffer>& buffers = ret.second;
+ auto [params, buffers] =
CreateParamsAndBuffers(GetStructInfo(relax_param), //
+
relax_param->name_hint());
ICHECK_EQ(params.size(), buffers.size());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.buffer_map.Set(params[i], buffers[i]);
@@ -384,10 +376,8 @@ class FusedTIRConstructor : public ExprVisitor {
// Step 4. Append symbolic vars
const relax::Var& last_relax_param = func->params.back();
if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
- auto ret =
+ auto [params, buffers] =
CreateParamsAndBuffers(GetStructInfo(last_relax_param),
last_relax_param->name_hint());
- const Array<tir::Var>& params = ret.first;
- const Array<tir::Buffer>& buffers = ret.second;
ICHECK(buffers.empty());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.params.push_back(params[i]);
@@ -682,9 +672,7 @@ class FusedTIRConstructor : public ExprVisitor {
"list.";
if (index == -1) index = 0;
for (size_t i = 0; i < tuple->fields.size(); ++i) {
- auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
- const Array<tir::Var>& ret_params = ret.first;
- const Array<tir::Buffer>& ret_buffers = ret.second;
+ auto [ret_params, ret_buffers] =
CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
@@ -714,19 +702,18 @@ class FusedTIRConstructor : public ExprVisitor {
tir::PrimFunc ConstructFunc() {
Map<String, ObjectRef> attr_map;
attr_map.Set("tir.noalias", tir::const_true());
- tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
- func_info_.symbolic_var_remap);
+ tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map,
func_info_.symbolic_var_remap);
ICHECK(func_info_.global_name != "fused");
// Remove output buffers from func_info_.alloc_buffers
Array<tir::Buffer> alloc_buffers;
for (const tir::Buffer& buf : func_info_.alloc_buffers) {
if (func_info_.output_buffers.count(buf.get()) == 0) {
- alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
+ alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf));
}
}
tir::Stmt body =
tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
- body = substitor.Substitute(body);
+ body = subst.Substitute(body);
body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt,
alloc_buffers);
body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
tir::PrimFunc func(func_info_.params, body, VoidType(),
func_info_.buffer_map,
@@ -804,9 +791,7 @@ class TIRFuseMutator : public ExprMutator {
// Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty
block builder.
TIRFuseMutator mutator(mod);
// Step 1. Fuse all primitive relax functions, store the result in
`fused_tir_funcs_`
- for (const auto& kv : mod->functions) {
- const GlobalVar& gv = kv.first;
- const BaseFunc& func = kv.second;
+ for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() &&
func->HasNonzeroAttr(attr::kPrimitive)) {
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
@@ -816,9 +801,7 @@ class TIRFuseMutator : public ExprMutator {
// Step 2. Update all non-primitive relax functions and add it, with the
dependent function,
// into the new IRModule
- for (const auto& kv : mod->functions) {
- const GlobalVar& gv = kv.first;
- const BaseFunc& func = kv.second;
+ for (const auto& [gv, func] : mod->functions) {
if (func->IsInstance<relax::FunctionNode>() &&
!func->HasNonzeroAttr(attr::kPrimitive)) {
relax::Function update_func =
Downcast<Function>(mutator.VisitExpr(func));
mutator.builder_->AddFunction(update_func, gv->name_hint);
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index 8b856d3cc5..c7aa7984be 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -883,5 +883,125 @@ def test_symbolic_var_in_call_tir_args():
_check(Before, Expected)
+def test_same_buffer_multiple_read():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def concatenate(
+ rxplaceholder: T.Buffer((T.int64(1), T.int64(4), T.int64(64),
T.int64(64)), "float32"),
+ rxplaceholder_1: T.Buffer(
+ (T.int64(1), T.int64(4), T.int64(64), T.int64(64)), "float32"
+ ),
+ T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64),
T.int64(64)), "float32"),
+ ):
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4),
T.int64(64), T.int64(64)):
+ with T.block("T_concat"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(
+ rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2,
v_ax3],
+ rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+ )
+ T.writes(T_concat[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
+ T.int64(1) <= v_ax0,
+ rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2,
v_ax3],
+ rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
+ )
+
+ @T.prim_func
+ def transpose2(
+ rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64),
T.int64(64)), "float32"),
+ T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64),
T.int64(4)), "float32"),
+ ):
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64),
T.int64(64), T.int64(4)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax3, v_ax1, v_ax2])
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
+ v_ax0, v_ax3, v_ax1, v_ax2
+ ]
+
+ @R.function
+ def fused_concatenate_transpose2(
+ inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+ ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.concatenate,
+ (inp_0, inp_0),
+ out_sinfo=R.Tensor((2, 4, 64, 64), dtype="float32"),
+ )
+ gv = R.call_tir(
+ cls.transpose2, (lv,), out_sinfo=R.Tensor((2, 64, 64, 4),
dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+ ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ cls = Module
+ with R.dataflow():
+ lv = cls.fused_concatenate_transpose2(inp_0)
+ R.output(lv)
+ return lv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def fused_concatenate_transpose2(
+ inp_0: T.Buffer((T.int64(1), T.int64(4), T.int64(64),
T.int64(64)), "float32"),
+ T_transpose_handle_intermediate: T.Buffer(
+ (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ T_concat_handle_intermediate = T.alloc_buffer(
+ (T.int64(2), T.int64(4), T.int64(64), T.int64(64))
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4),
T.int64(64), T.int64(64)):
+ with T.block("T_concat"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3])
+ T.writes(T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_concat_handle_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(
+ T.int64(1) <= v_ax0,
+ inp_0[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
+ inp_0[v_ax0, v_ax1, v_ax2, v_ax3],
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64),
T.int64(64), T.int64(4)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1,
v_ax2])
+ T.writes(T_transpose_handle_intermediate[v_ax0, v_ax1,
v_ax2, v_ax3])
+ T_transpose_handle_intermediate[
+ v_ax0, v_ax1, v_ax2, v_ax3
+ ] = T_concat_handle_intermediate[v_ax0, v_ax3, v_ax1,
v_ax2]
+
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 4, 64, 64), dtype="float32")
+ ) -> R.Tensor((2, 64, 64, 4), dtype="float32"):
+ R.func_attr({"num_input": 3})
+ cls = Expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.fused_concatenate_transpose2,
+ (inp_0,),
+ out_sinfo=R.Tensor((2, 64, 64, 4), dtype="float32"),
+ )
+ R.output(lv)
+ return lv
+
+ _check(Module, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()