MarisaKirisame commented on a change in pull request #6238:
URL: https://github.com/apache/incubator-tvm/pull/6238#discussion_r475549695
##########
File path: python/tvm/tir/transform/transform.py
##########
@@ -500,11 +500,19 @@ def VerifyMemory():
"""
return _ffi_api.VerifyMemory()
-def HoistIfThenElse():
+def HoistIfThenElse(variant=None):
"""Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
+
+ Parameters
+ ----------
+ variant : str
+ The variant of the pass.
+
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.HoistIfThenElse()
+ if variant is None:
+ return _ffi_api.HoistIfThenElse()
+ return _ffi_api.HoistIfThenElseBasic()
Review comment:
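Check explicitly that the variant is "basic" (e.g. `if variant == "basic":`) instead of treating any non-None value as the basic variant.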
##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -126,10 +167,19 @@ class HoistCandidateSelector final : public StmtExprVisitor {
// Maintain list of all vars in AttrStmt
// To stop hoisting if any of the block variables are used.
//
- // NOTE: If in future
- // hoisting is required for any specific case,
- // then add exception to only those case
- // rather than allowing for all.
+ // In case we want to use hoisting in between certain passes
+ // which have interdependencies of the postioning of if nodes with scope var
+ // it is better to disable this section
+ if (support_block_scope_hosting_) {
+ if (IsRecordingOn()) {
+ StartOrAddRecord(GetRef<ObjectRef>(op));
+ StmtExprVisitor::VisitStmt_(op);
+ RemoveRecord(GetRef<ObjectRef>(op));
+ return;
+ }
+
+ return StmtExprVisitor::VisitStmt_(op);
Review comment:
Use an if/else here instead of the early return.
##########
File path: tests/python/unittest/test_tir_transform_hoist_if.py
##########
@@ -255,6 +259,488 @@ def test_multi_if():
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
+def test_no_hoisting_1():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(k >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ n = te.var("n")
+ x = te.var("x")
+
+ with ib.for_range(0, 10, "i") as i:
+ with ib.for_range(0, 10, "j") as j:
+ with ib.for_range(0, 10, "k") as k:
+ with ib.if_scope(i >= 3):
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
+ data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.if_scope((tx + j) < 9):
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + k) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
+ n = te.size_var("n")
+ m = te.size_var("m")
+ A = te.placeholder((n, m), name='A')
+ k = te.reduce_axis((0, m), "k")
+ B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+ s = te.create_schedule(B.op)
+ ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+ BF = s.rfactor(B, ki)
+ xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+ s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+ s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+ s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+ s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B], "main", None)["main"]
+ stmt = func.body
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ #ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ dshape_inner = (33, 63)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+ ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(tx < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ #tvm.ir.assert_structural_equal(new_stmt, stmt)
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+ nn = 1024
+ n = tvm.runtime.convert(nn)
+ A = te.placeholder((n,), name='A')
+ B = te.placeholder((n,), name='B')
+ AA = te.compute((n,), lambda *i: A(*i), name='A')
+ BB = te.compute((n,), lambda *i: B(*i), name='B')
+ T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+ C = te.compute(A.shape, lambda *i: T(*i), name='C')
+ s = te.create_schedule(C.op)
+ xo, xi = s[C].split(C.op.axis[0], factor=4)
+ xo1, xo2 = s[C].split(xo, factor=13)
+ s[C].parallel(xo2)
+ s[C].pragma(xo1, "parallel_launch_point")
+ s[C].pragma(xo2, "parallel_stride_pattern")
+ s[C].pragma(xo2, "parallel_barrier_when_finish")
+ s[C].vectorize(xi)
+ func = tvm.driver.build_module.form_irmodule(
+ s, [A, B, C], "main", None)["main"]
+ stmt = func.body
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+ ib = tvm.tir.ir_builder.create()
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+ g = te.var('g')
+
+ ib.scope_attr(data, "storage_scope", "global")
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope(data[g] < 3):
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
+ with ib.else_scope():
+ data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+ stmt = new_stmt
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + n) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
+ ib = tvm.tir.ir_builder.create()
+ dshape = (32, 64)
+ data = ib.pointer("float32", name="data")
+ l = te.var('l')
+ m = te.var('m')
+ n = te.var('n')
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", dshape[0])
+ ib.scope_attr(bx, "thread_extent", dshape[1])
+ with ib.for_range(0, l, "i") as i:
+ with ib.for_range(0, m, "j") as j:
+ with ib.for_range(0, n, "k") as k:
+ with ib.if_scope((tx + i) < 3):
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+ with ib.else_scope():
+ data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+ stmt = ib.get()
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+ with tvm.transform.PassContext(config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+@pytest.mark.skip()
+def test_hoisting_op_conv():
+ dtype = "float32"
+ dshape = (1, 80, 73, 73)
+ kshape = (192, 80, 3, 3)
+ padding=(1, 1)
+ groups=1
+ dilation=(1, 1)
+ kernel_size=(3, 3)
+ channels=192
+ scale=1
+ x = relay.var("x", shape=dshape, dtype=dtype)
+ w = relay.var("w", shape=kshape, dtype=dtype)
+ y = relay.nn.conv2d(x, w, padding=padding,
+ dilation=dilation,
+ groups=groups,
+ channels=channels,
+ kernel_size=kernel_size)
+
+ func = relay.Function([x, w], y)
+ mod = tvm.IRModule()
+ mod['main'] = func
+ mod = relay.transform.InferType()(mod)
+
+ data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+ kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+
+ params = {'w': tvm.nd.array(kernel)}
+ for target, ctx in ctx_list():
+ with tvm.transform.PassContext(opt_level=3):
+ graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+ m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+ x = np.random.uniform(size=dshape)
+ data_tvm = tvm.nd.array(data)
+ m.set_input('x', data_tvm)
+ m.set_input(**params)
+ m.run()
+ e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+ t1 = e(data_tvm).results
+ t1 = np.array(t1) * 1000
+ print('{} ms'.format(t1.mean()))
+
+ with tvm.transform.PassContext(opt_level=3, config={
+ "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+ }):
+ graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+ m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+ x = np.random.uniform(size=dshape)
+ data_tvm = tvm.nd.array(data)
+ m.set_input('x', data_tvm)
+ m.set_input(**params)
+ m.run()
+ e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+ t2 = e(data_tvm).results
+ t2 = np.array(t2) * 1000
+
+ print('{} ms'.format(t2.mean()))
+ tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
if __name__ == "__main__":
test_hoist_top_for()
Review comment:
Instead of calling each test by hand in the `__main__` block, run the module through pytest:
import pytest
pytest.main([__file__])
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org