roastduck commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r475371697
##########
File path: python/tvm/driver/build_module.py
##########
@@ -181,7 +181,7 @@ def lower(sch,
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
- tvm.tir.transform.HoistIfThenElse(),
+ tvm.tir.transform.HoistIfThenElse("basic"),
Review comment:
Do we really need this "basic" pass? Could we instead perform the complete
hoisting pass directly in Phase 3?
##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -500,11 +500,19 @@ def VerifyMemory():
"""
return _ffi_api.VerifyMemory()
-def HoistIfThenElse():
+def HoistIfThenElse(variant=None):
"""Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
+
+ Parameters
+ ----------
+ variant : str
+ The variant of the pass.
+
Review comment:
Can you explain the different variants here in the docstring?
##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -277,18 +327,21 @@ class HoistCandidateSelector final : public
StmtExprVisitor {
return false;
}
- std::vector<const ForNode*> ordered_for_list_;
+ std::vector<const Object*> ordered_list_;
Review comment:
Can you add a comment here clarifying that `ordered_list_` only ever holds
pointers to `ForNode` and `AttrNode` objects?
##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -147,26 +197,23 @@ class HoistCandidateSelector final : public
StmtExprVisitor {
if (CheckValidIf()) {
// Check corresponding for loop
- bool match_found = false;
- size_t match_for_loop_pos = 0;
+ int match_for_loop_pos = -1;
for (auto var : if_var_list_) {
- for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
- if (ordered_for_list_[i] == var_for_map_[var]) {
+ for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) {
Review comment:
Suggestion: a better way to write this kind of loop is `for (int i
= 0, i_end = static_cast<int>(xxx.size()); i < i_end; i++)`. It is hard for
a C++ compiler to prove that `.size()` is loop-invariant, so it is usually
re-evaluated on every iteration. This is just a suggestion; you can keep your
original code, since this is not a performance bottleneck.
##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
+def test_no_hoisting_1():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(k >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+ x = te.var("x")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(i >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.3
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
Review comment:
Can we hoist the attributes together in this case? If that is correct and
never degrades performance, we can leave it for future work.
##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -93,11 +112,33 @@ using HoistForIfTuple = std::tuple<bool, const ForNode*,
const IfThenElseNode*>;
* if (likely(j > 2))
* A[i+j+k] = B[i+j+k]
*
+ *
+ * This pass do hoisting for Block scope variables also.
Review comment:
The term "Block scope variables" is a bit confusing to me. Is it
equivalent to saying "variables defined in `Attr` nodes"?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]