junrushao1994 commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r838937818
##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4,
atol=2e-4)
# Tensor-intrinsic *description* of a 16x4 uint8 x int8 -> int32 dot product:
#   C[i] += int32(A[k]) * int32(B[i, k])   for i in [0, 16), k in [0, 4)
# It is paired with dot_product_intrin in the
# tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_intrin) call below.
# Only `#` comments are used here: this body is parsed as TVMScript, not plain Python.
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            with T.init():
                # Zero the accumulator for output lane i.
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    # i is bound as a spatial ("S") axis and k as a reduction ("R") axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
# Tensor-intrinsic *implementation* paired with dot_product_desc.
# It performs the whole 16x4 uint8 x int8 -> int32 update on C[0:16] with a single call
# to the LLVM intrinsic llvm.x86.avx512.vpdpbusd.512.
# Only `#` comments are used here: this body is parsed as TVMScript, not plain Python.
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Load the 4 uint8 elements of A and reinterpret those 4 bytes as one int32.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Load all 16x4 int8 elements of B as 64 lanes, reinterpreted as 16 int32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # Note: this is an update (+=). The intrinsic's result is added to the existing
        # contents of C[0:16]. T.int32x16(0) is presumably the intrinsic's zero
        # accumulator operand; confirm against the VPDPBUSD operand order.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            # Every output lane uses the same 4 bytes of A, so A_i32 is broadcast to 16 lanes.
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )


# Name under which the (dot_product_desc, dot_product_intrin) pair is registered as a
# tir.TensorIntrin. schedule_dense refers to it in sch.tensorize.
VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
def schedule_dense(block, M, do_tune, sch):
    """Manually schedule the dense ``block`` of ``sch`` and tensorize it with VNNI_INTRIN.

    Steps:

    1. Inline every intermediate consumer (epilogue) block.
    2. Tile the last two loops of the final block, with x split by 16.
    3. If there is an epilogue, compute the dense block at that tile.
    4. Split the reduction loop by 4.
    5. Parallelize the fused outer loop.
    6. Decompose the reduction and vectorize the init loop.
    7. Tensorize the resulting 16x4 inner loop nest.

    Parameters
    ----------
    block : BlockRV
        The dense block to schedule.
    M : int or None
        Upper bound for the fixed inner tile of the y loop (``min(M, 64)``). It is only
        read when ``do_tune`` is False, so it may be None otherwise.
    do_tune : bool
        If True, sample the y tiling with ``sample_perfect_tile`` instead of using a
        fixed split.
    sch : tvm.tir.Schedule
        The schedule to transform. It is modified in place; nothing is returned.
    """
    post_blocks = sch.get_consumers(block)

    if len(post_blocks) > 0:
        # Walk the consumer graph level by level. Inline every block that itself has
        # consumers, until the frontier contains only terminal blocks.
        # NOTE(review): a consumer shared by several blocks is appended more than once
        # to next_post_blocks; this loop assumes a simple chain of consumers.
        while True:
            next_post_blocks = []
            for post_block in post_blocks:
                next_consumers = sch.get_consumers(post_block)

                if len(next_consumers) > 0:
                    sch.compute_inline(post_block)

                next_post_blocks += next_consumers

            if len(next_post_blocks) == 0:
                # Exactly one terminal block is expected. Its last two loops are the
                # output loops that are tiled below.
                assert len(post_blocks) == 1
                outer_block = post_blocks[0]
                a_y, a_x = sch.get_loops(outer_block)[-2:]
                break

            post_blocks = next_post_blocks
    else:
        # No epilogue, so the dense block is tiled directly. Its last three loops are
        # presumably (y, x, reduction k); TODO confirm against the dense workload.
        a_y, a_x, _ = sch.get_loops(block)[-3:]
        outer_block = block

    if do_tune:
        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
        a_yo, a_yi = sch.split(a_y, factors=y_factors)
    else:
        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])

    # The inner x tile of 16 matches the 16 output lanes of the intrinsic (C is (16,)).
    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
    sch.reorder(a_yo, a_xo, a_yi, a_xi)
    fused = sch.fuse(a_yo, a_xo)

    if outer_block != block:
        # Epilogue case: vectorize the epilogue's inner x loop, then move the dense
        # computation under the y tile of the epilogue block.
        sch.vectorize(a_xi)
        sch.compute_at(block, a_yi)

    # Re-fetch the two innermost loops of the dense block (presumably x and reduction k).
    a_xi, a_k = sch.get_loops(block)[-2:]
    # The inner reduction tile of 4 matches the length-4 dot product of the intrinsic.
    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
    sch.reorder(a_ko, a_xi, a_ki)

    sch.parallel(fused)
    dec = sch.decompose_reduction(block, a_ko)

    # Vectorize the innermost loop of the init block returned by decompose_reduction.
    init_loop = sch.get_loops(dec)[-1]
    sch.vectorize(init_loop)

    # Replace the (x: 16) x (k: 4) loop nest rooted at a_xi with the registered intrinsic.
    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
def manual_tir_common(do_tune=False):
    """Build a dense + batch_matmul Relay model with meta schedule and check it against the VM.

    The dense operator is scheduled by ``schedule_dense``; batch_matmul and the other
    operators fall back to their TE schedules. The graph-executor result is compared
    exactly with the Relay VM reference output.

    Parameters
    ----------
    do_tune : bool
        If True, tune the extracted dense tasks with ``tune_extracted_tasks``. This
        presumably relies on the ``meta_schedule.dense_vnni`` schedule rule registered
        by the caller. If False, commit one hand-written ``schedule_dense`` trace per
        dense task to a fresh ``JSONDatabase``.
    """
    M, N, K = 1024, 1024, 1024
    data_shape = (M, K)
    weight_shape = (N, K)

    data_dtype = "uint8"
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")

    # dense is tuned by the TIR schedule above, bmm is scheduled by TE (topi/x86/batch_matmul.py)
    dense = relay.nn.dense(data, weight, out_dtype="int32")
    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
    out = relay.nn.batch_matmul(
        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
        out_dtype="int32",
    )

    relay_mod = tvm.IRModule.from_expr(out)

    target = "llvm -mcpu=cascadelake -num-cores 4"
    dev = tvm.device(target, 0)

    # Concrete inputs use *_np names so they do not shadow the Relay variables above.
    data_np = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

    ref = (
        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
        .evaluate()(data_np, weight_np, bias_np)
        .numpy()
    )

    params = {"weight": weight_np, "bias": bias_np}

    extracted_tasks = extract_task_from_relay(relay_mod, target, params)

    tune_tasks = [task for task in extracted_tasks if "dense" in task.task_name]

    with tempfile.TemporaryDirectory() as work_dir:
        if do_tune:
            config = ReplayTraceConfig(
                num_trials_per_iter=64,
                num_trials_total=64,
            )
            database = tune_extracted_tasks(
                tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
            )
        else:
            database = JSONDatabase(
                path_workload=osp.join(work_dir, "database_workload.json"),
                path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
            )

            for task in tune_tasks:
                mod = Parse._mod(task.dispatched[0])
                workload = database.commit_workload(mod)

                sch = tvm.tir.Schedule(mod)
                block = sch.get_block("compute")
                schedule_rule = sch.get(block).annotations["schedule_rule"]

                if "dense_vnni" in schedule_rule:
                    schedule_dense(block, M, False, sch)

                # [0.0] is a dummy run time. This is the only record for the workload in
                # the freshly created database, so ApplyHistoryBest has a single choice.
                tune_rec = TuningRecord(
                    sch.trace, [0.0], workload, tvm.target.Target(target), []
                )

                database.commit_tuning_record(tune_rec)

        with ApplyHistoryBest(database):
            with tvm.transform.PassContext(
                opt_level=3,
                config={"relay.backend.use_meta_schedule": True},
            ):
                # The log should say
                #   meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
                #   meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
                #   meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
                #   meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
                #
                # This means batch matmul and others are scheduled by TE, and dense (the one
                # not warned) is found in the meta schedule tuning database during
                # ApplyHistoryBest.
                lib = relay.build(relay_mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    runtime.set_input("data", data_np)
    runtime.run()

    out_np = runtime.get_output(0).numpy()

    np.testing.assert_equal(out_np, ref)
+
+
[email protected]("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+ tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc,
dot_product_intrin)
+
+ manual_tir_common(do_tune=False)
+
+ def schedule_rule_dense_vnni(sch, block):
+ schedule_dense(block, None, True, sch)
+ return [sch]
+
+ register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
Review comment:
@masahi Thanks for the comprehensive discussion about API design and
user intention!
It looks like we have three different proposals:
- P1. Always annotate schedule_rule, even if it's not registered; print a
warning if the schedule_rule isn't registered.
- P2. Always make sure the schedule_rule exists when annotating, and error out
if it doesn't.
- P3. Always annotate schedule_rule, even if it's not registered; emit no
warning if the schedule_rule isn't registered.
Masa and I both agree that P3 may not be ideal, because silently ignoring
an abnormal condition may not be the best user experience.
> In the future when auto-tensorization is ready, I want to freely switch
between manual and automatic scheduling.
Totally agree about future possibility of switching around! Definitely it's
going to be a lot of fun :-)
Therefore, I would conclude that emitting a warning (P1) is probably what both
@masahi and I agree on. Additionally, we might enhance the search space
generation to selectively check whether the schedule_rule is allowed by a
target-specific allowlist, but that is probably not a high priority for now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]