ANSHUMAN87 commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r476308693



##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
                        ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
+def test_no_hoisting_1():
+    """Check that HoistIfThenElse is a no-op when the condition reads only `k`.
+
+    The `if k >= 3` guard depends on the innermost loop variable `k`, so it is
+    not invariant in any enclosing loop. The test asserts the pass returns a
+    statement structurally equal to the input, both with the default config
+    and with `support_block_scope_hosting` enabled.
+    """
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")  # NOTE(review): unused in this test
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(k >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 
0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    # Default configuration: the body must come back unchanged.
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    # Same expectation with block-scope hoisting enabled.
+    # NOTE(review): key is spelled "hosting", not "hoisting" — confirm it
+    # matches the option name registered by the pass config.
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+    """Check that HoistIfThenElse is a no-op when the guarded store has a sibling.
+
+    The `if i >= 3` guard reads only the outermost loop variable `i`. However,
+    inside the innermost loop it is followed by an unconditional store to the
+    same element. The test asserts the pass returns a statement structurally
+    equal to the input, both with the default config and with
+    `support_block_scope_hosting` enabled.
+    """
+    # Presumably hoisting is skipped because the IfThenElse is not the only
+    # statement in the loop body — confirm against the pass implementation.
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")  # NOTE(review): unused in this test
+    x = te.var("x")  # NOTE(review): unused in this test
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(i >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 
0.3
+                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    # Default configuration: the body must come back unchanged.
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    # Same expectation with block-scope hoisting enabled.
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                ib.scope_attr(bx, "thread_extent", dshape_inner[1])

Review comment:
       As I understand it, these should not be hoisted: such statements exist 
in practice precisely to override the values of globally scoped variables 
within a particular sub-scope. However, if you have a real-world scenario where 
hoisting is applicable, please share it with me, and I will take it up in 
future work.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to