ANSHUMAN87 commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r475057328
##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
+def test_no_hoisting_1():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(k >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+ x = te.var("x")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(i >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] +
0.3
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.if_scope((tx + j) < 9):
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k]
+ 0.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
+ n = te.size_var("n")
+ m = te.size_var("m")
+ A = te.placeholder((n, m), name='A')
+ k = te.reduce_axis((0, m), "k")
+ B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+ s = te.create_schedule(B.op)
+ ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+ BF = s.rfactor(B, ki)
+ xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+ s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+ s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+ s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+ s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B], "main", None)["main"]
+ stmt = func.body
+ new_stmt =
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt =
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ #ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+ nn = 1024
+ n = tvm.runtime.convert(nn)
+ A = te.placeholder((n,), name='A')
+ B = te.placeholder((n,), name='B')
+ AA = te.compute((n,), lambda *i: A(*i), name='A')
+ BB = te.compute((n,), lambda *i: B(*i), name='B')
+ T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+ C = te.compute(A.shape, lambda *i: T(*i), name='C')
+ s = te.create_schedule(C.op)
+ xo, xi = s[C].split(C.op.axis[0], factor=4)
+ xo1, xo2 = s[C].split(xo, factor=13)
+ s[C].parallel(xo2)
+ s[C].pragma(xo1, "parallel_launch_point")
+ s[C].pragma(xo2, "parallel_stride_pattern")
+ s[C].pragma(xo2, "parallel_barrier_when_finish")
+ s[C].vectorize(xi)
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B, C], "main", None)["main"]
+ stmt = func.body
+ new_stmt =
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt =
tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+ g = te.var('g')
+
+ ib.scope_attr(data, "storage_scope", "global")
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(data[g] < 3):
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
+ with ib.else_scope():
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+ stmt = new_stmt
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + n) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + i) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] +
1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+@pytest.mark.skip()
+def test_hoisting_op_conv():
Review comment:
Note: While adding a performance benchmarking test case like this, I am
seeing strange behavior in its execution time. Sometimes the
test case hangs for a very long time; even the CI run halted because of this.
I am currently debugging it. That is why the test case is skipped for the
time being.
However, I have verified that this behavior is not caused by the hoisting pass added in this PR,
so I think we can continue with the current PR. Thanks!
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]