ANSHUMAN87 commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r475057328



##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
                        ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
+def test_no_hoisting_1():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(k >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 
0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
+    x = te.var("x")
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(i >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 
0.3
+                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + k) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.if_scope((tx + j) < 9):
+                with ib.for_range(0, n, "k") as k:
+                    with ib.if_scope((tx + k) < 3):
+                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  
+ 0.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
+    n = te.size_var("n")
+    m = te.size_var("m")
+    A = te.placeholder((n, m), name='A')
+    k = te.reduce_axis((0, m), "k")
+    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+    s = te.create_schedule(B.op)
+    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+    BF = s.rfactor(B, ki)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+    s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+    s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+    s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+    func = tvm.driver.build_module.form_irmodule(
+            s, [A, B], "main", None)["main"]
+    stmt = func.body
+    new_stmt = 
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = 
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    #ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(bx, "thread_extent", dshape[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    #tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    #tvm.ir.assert_structural_equal(new_stmt, stmt)
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+    nn = 1024
+    n = tvm.runtime.convert(nn)
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    AA = te.compute((n,), lambda *i: A(*i), name='A')
+    BB = te.compute((n,), lambda *i: B(*i), name='B')
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    s = te.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    xo1, xo2 = s[C].split(xo, factor=13)
+    s[C].parallel(xo2)
+    s[C].pragma(xo1, "parallel_launch_point")
+    s[C].pragma(xo2, "parallel_stride_pattern")
+    s[C].pragma(xo2, "parallel_barrier_when_finish")
+    s[C].vectorize(xi)
+    func = tvm.driver.build_module.form_irmodule(
+            s, [A, B, C], "main", None)["main"]
+    stmt = func.body
+    new_stmt = 
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = 
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    g = te.var('g')
+
+    ib.scope_attr(data, "storage_scope", "global")
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(data[g] < 3):
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 0.3
+                with ib.else_scope():
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]  + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+    stmt = new_stmt
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + n) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + i) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 
1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
[email protected]()
+def test_hoisting_op_conv():

Review comment:
       Note: While adding a performance benchmarking test case like this, I am 
seeing unexpected behavior in execution time. Sometimes the test case hangs for 
a very long time; CI execution even halted because of it.
   I am currently debugging it, which is why the test case is skipped for the 
time being.
   However, I have verified that this behavior is not caused by the newly added 
hoisting pass, so I think we can proceed with the current PR. Thanks!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to