ANSHUMAN87 commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r476308693
##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
+def test_no_hoisting_1():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(k >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+ x = te.var("x")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(i >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.3
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
Review comment:
As I understand it, these should not be hoisted: such statements exist in
practice to override the global-scope variable values within a particular
sub-scope. But if you have a real-world scenario where hoisting is
applicable, please share it with me. I will definitely take it up in my
future work.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]