masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r835675916
##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +326,227 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4,
atol=2e-4)
+@register
+def int32x16(imm, span):
+ return imm.astype("int32x16", span)
+
+
[email protected]_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+ B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+ C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+ for i in T.serial(0, 16):
+ with T.init():
+ C[i] = T.int32(0)
+ for k in T.serial(0, 4):
+ with T.block("update"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk],
"int32")
+
+
[email protected]_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+ B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+ C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+
+ A_u8x4 = A.vload([0], "uint8x4")
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+ B_i8x64 = B.vload([0, 0], dtype="int8x64")
+ B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+ C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this
is an update +=
+ T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+ T.uint32(0),
+ T.int32x16(0),
+ T.broadcast(A_i32, 16),
+ B_i32x16,
+ dtype="int32x16",
+ )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+ post_blocks = sch.get_consumers(block)
+
+ if len(post_blocks) > 0:
+ while True:
+ next_post_blocks = []
+ for post_block in post_blocks:
+ next_consumers = sch.get_consumers(post_block)
+
+ if len(next_consumers) > 0:
+ sch.compute_inline(post_block)
+
+ next_post_blocks += next_consumers
+
+ if len(next_post_blocks) == 0:
+ assert len(post_blocks) == 1
+ outer_block = post_blocks[0]
+ a_y, a_x = sch.get_loops(outer_block)[-2:]
+ break
+
+ post_blocks = next_post_blocks
+ else:
+ a_y, a_x, _ = sch.get_loops(block)[-3:]
+ outer_block = block
+
+ if do_tune:
+ y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+ a_yo, a_yi = sch.split(a_y, factors=y_factors)
+ else:
+ a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+ a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+ sch.reorder(a_yo, a_xo, a_yi, a_xi)
+ fused = sch.fuse(a_yo, a_xo)
+
+ if outer_block != block:
+ sch.vectorize(a_xi)
+ sch.compute_at(block, a_yi)
+
+ a_xi, a_k = sch.get_loops(block)[-2:]
+ a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+ sch.reorder(a_ko, a_xi, a_ki)
+
+ sch.parallel(fused)
+
+ dec = sch.decompose_reduction(block, a_ko)
+
+ init_loop = sch.get_loops(dec)[-1]
+ sch.vectorize(init_loop)
+
+ sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+ M, N, K = 1024, 1024, 1024
+ data_shape = (M, K)
+ weight_shape = (N, K)
+
+ data_dtype = "uint8"
+ data = relay.var("data", shape=data_shape, dtype=data_dtype)
+ weight = relay.var("weight", shape=weight_shape, dtype="int8")
+ bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+ # dense is tuned by the TIR schedule above, bmm is scheduled by TE
(topi/x86/batch_matmul.py)
+ dense = relay.nn.dense(data, weight, out_dtype="int32")
+ bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+ out = relay.nn.batch_matmul(
+ relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+ relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+ out_dtype="int32",
+ )
+
+ relay_mod = tvm.IRModule.from_expr(out)
+
+ target = "llvm -mcpu=cascadelake -num-cores 4"
+ dev = tvm.device(target, 0)
+
+ data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+ weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+ bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+ ref = (
+ relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+ .evaluate()(*[data, weight_np, bias_np])
+ .numpy()
+ )
+
+ params = {"weight": weight_np, "bias": bias_np}
+
+ extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+ tune_tasks = list(
+ filter(
+ lambda task: "dense" in task.task_name,
+ extracted_tasks,
+ )
+ )
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ if do_tune:
+ config = ReplayTraceConfig(
+ num_trials_per_iter=64,
+ num_trials_total=64,
+ )
+ database = tune_extracted_tasks(tune_tasks, target, config,
work_dir=work_dir)
+ else:
+ database = JSONDatabase(
+ path_workload=osp.join(work_dir, "database_workload.json"),
+ path_tuning_record=osp.join(work_dir,
"database_tuning_record.json"),
+ )
+
+ for task in tune_tasks:
+ mod = Parse._mod(task.dispatched[0])
+ workload = database.commit_workload(mod)
+
+ sch = tvm.tir.Schedule(mod)
+ block = sch.get_block("compute")
+ schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+ if "dense_vnni" in schedule_rule:
+ schedule_dense(block, M, False, sch)
+
+ tune_rec = TuningRecord(sch.trace, [0.0], workload,
tvm.target.Target(target), [])
+
+ database.commit_tuning_record(tune_rec)
+
+ with ApplyHistoryBest(database):
+ with tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ """
+ The log should say
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_expand_dims
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_cast
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_cast_1
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_nn_batch_matmul
+
+ This means batch matmul and others are scheduled by TE, and dense
(the one not warned) is found in the
+ meta schedule tuning database during ApplyHistoryBest
+ """
+ lib = relay.build(relay_mod, target=target, params=params)
+
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ runtime.set_input("data", data)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+
+ np.testing.assert_equal(out, ref)
+
+
[email protected]("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+ tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc,
dot_product_intrin)
+
+ manual_tir_common(do_tune=False)
+
+ def schedule_rule_dense_vnni(sch, block):
+ schedule_dense(block, None, True, sch)
+ return [sch]
+
+ register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
+
+ # TODO(masahi): Weird error from tuning with CheckSubtreeCompactDataflow
in for_kind.cc turned on
+ # manual_tir_common(do_tune=True)
Review comment:
It's not about tensorization, but parallelization. Now I have more
information about this error. If I turn on detailed logging, I get this from
`ReplayTraceNode`:
```
for i0_0_i1_0_fused_fused in T.serial(8192):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for i0_1 in T.serial(8):
for ax0_init in T.vectorized(16):
with T.block("compute_init"):
i = T.axis.spatial(1024, i0_0_i1_0_fused_fused // 64
* 8 + i0_1)
j = T.axis.spatial(1024, i0_0_i1_0_fused_fused % 64
* 16 + ax0_init)
T.reads()
T.writes(compute[i, j])
T.block_attr({"schedule_rule":"meta_schedule.dense_vnni",
"workload":["dense_vnni.x86", ["TENSOR", [1024, 1024], "uint8"], ["TENSOR",
[64, 256, 16, 4], "int8"], None, "int32"]})
compute[i, j] = 0
for ax1_0 in T.serial(256):
# tir.Block#1
with T.block("compute_update_o"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
i = T.axis.spatial(1024, i0_0_i1_0_fused_fused // 64
* 8 + i0_1)
j_o = T.axis.spatial(64, i0_0_i1_0_fused_fused % 64)
k_o = T.axis.reduce(256, ax1_0)
T.reads(compute[i, j_o * 16 : j_o * 16 + 16],
placeholder[i, k_o * 4 : k_o * 4 + 4], placeholder_1[j_o, k_o, 0 : 16, 0 : 4])
T.writes(compute[i, j_o * 16 : j_o * 16 + 16])
A = T.match_buffer(placeholder[i, k_o * 4 : k_o * 4
+ 4], [4], dtype="uint8", offset_factor=1)
B = T.match_buffer(placeholder_1[j_o, k_o, 0 : 16, 0
: 4], [16, 4], dtype="int8", offset_factor=1)
C = T.match_buffer(compute[i, j_o * 16 : j_o * 16 +
16], [16], dtype="int32", offset_factor=1)
A_u8x4: T.uint8x4 = A[T.ramp(0, 1, 4)]
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
B_i8x64: T.int8x64 = B[0, T.ramp(0, 1, 64)]
B_i32x16: T.int32x16 = T.reinterpret(B_i8x64,
dtype="int32x16")
C[T.ramp(0, 1, 16)] = C[T.ramp(0, 1, 16)] +
T.call_llvm_pure_intrin(9785, T.uint32(0), T.broadcast(0, 16),
T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
for i1_1 in T.vectorized(16):
with T.block("T_add_1"):
ax0 = T.axis.spatial(1024, i0_0_i1_0_fused_fused //
64 * 8 + i0_1)
ax1 = T.axis.spatial(1024, i0_0_i1_0_fused_fused %
64 * 16 + i1_1)
T.reads(compute[ax0, ax1], placeholder_2[0, ax1])
T.writes(T_add[ax0, ax1])
T_add[ax0, ax1] = compute[ax0, ax1] +
placeholder_2[0, ax1] + 1
Error message: The queried subtree root tir.For#0 in SRef tree does not have
compact dataflow, because its child block tir.Block#1 on SRef tree is neither a
local complete block nor a local reduction block.
```
So even though I'm doing `parallel` before `decompose_reduction` to
work around that strict check, `ReplayTrace` is trying to apply `parallel` to a
schedule that is clearly already `decompose_reduction`-ed. This is very weird...
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]