masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r837893607
##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4,
atol=2e-4)
# TIR description of the 16x1x16 u8 * i8 -> i32 dot product that the VNNI
# intrinsic below implements: C[i] += sum over k of A[k] * B[i, k],
# with i in [0, 16) and k in [0, 4).
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    # Accumulate in int32 to avoid overflowing the 8-bit inputs.
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
# Hardware implementation of dot_product_desc using the AVX512-VNNI
# vpdpbusd instruction (u8 x i8 multiply-accumulate into i32 lanes).
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Reinterpret the four u8 activations as one i32 lane, then broadcast
        # it across all 16 lanes expected by vpdpbusd.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Reinterpret the 16x4 i8 weight tile as 16 i32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
+
+
# Name under which the VNNI dot-product tensor intrinsic is registered
# (see tir.TensorIntrin.register in the test below) and later tensorized.
VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
def schedule_dense(block, M, do_tune, sch):
    """Manually schedule a VNNI dense block (and its epilogue) on x86.

    block: the dense compute block in ``sch``.
    M: number of output rows; used to pick a fixed y-tile when not tuning
       (may be None when ``do_tune`` is True).
    do_tune: if True, sample the y-tile factors instead of using a fixed split.
    sch: the tvm.tir.Schedule being mutated in place.
    """
    post_blocks = sch.get_consumers(block)

    if len(post_blocks) > 0:
        # Inline every intermediate epilogue op (bias add, cast, ...) until
        # only the final consumer remains; tile that consumer's outer loops.
        while True:
            next_post_blocks = []
            for post_block in post_blocks:
                next_consumers = sch.get_consumers(post_block)

                if len(next_consumers) > 0:
                    sch.compute_inline(post_block)

                next_post_blocks += next_consumers

            if len(next_post_blocks) == 0:
                assert len(post_blocks) == 1
                outer_block = post_blocks[0]
                a_y, a_x = sch.get_loops(outer_block)[-2:]
                break

            post_blocks = next_post_blocks
    else:
        # No epilogue: tile the dense block itself (last loop is the reduction).
        a_y, a_x, _ = sch.get_loops(block)[-3:]
        outer_block = block

    if do_tune:
        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
        a_yo, a_yi = sch.split(a_y, factors=y_factors)
    else:
        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])

    # The x-inner extent of 16 matches the 16 output lanes of the VNNI intrin.
    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
    sch.reorder(a_yo, a_xo, a_yi, a_xi)
    fused = sch.fuse(a_yo, a_xo)

    if outer_block != block:
        # We tiled the epilogue block above; vectorize its inner loop and
        # move the dense block under the y-inner loop.
        sch.vectorize(a_xi)
        sch.compute_at(block, a_yi)

    # Split the reduction by 4 (the u8x4 accumulation width of vpdpbusd).
    a_xi, a_k = sch.get_loops(block)[-2:]
    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
    sch.reorder(a_ko, a_xi, a_ki)

    sch.parallel(fused)
    dec = sch.decompose_reduction(block, a_ko)

    init_loop = sch.get_loops(dec)[-1]
    sch.vectorize(init_loop)

    # Map the (x-inner, k-inner) loop nest onto the registered VNNI intrinsic.
    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
def manual_tir_common(do_tune=False):
    """Build, schedule, run and verify a u8/i8 dense + batch_matmul workload.

    The dense op is scheduled by the manual TIR schedule above (optionally
    after meta-schedule tuning when ``do_tune`` is True); batch_matmul and
    the remaining ops fall back to the TE schedules. The graph-executor
    output is compared against a VM reference run.
    """
    M, N, K = 1024, 1024, 1024
    data_shape = (M, K)
    weight_shape = (N, K)

    data_dtype = "uint8"
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")

    # dense is tuned by the TIR schedule above, bmm is scheduled by TE
    # (topi/x86/batch_matmul.py)
    dense = relay.nn.dense(data, weight, out_dtype="int32")
    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
    out = relay.nn.batch_matmul(
        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
        out_dtype="int32",
    )

    relay_mod = tvm.IRModule.from_expr(out)

    target = "llvm -mcpu=cascadelake -num-cores 4"
    dev = tvm.device(target, 0)

    # Use data_np (not `data`) so the relay.var above is not shadowed.
    data_np = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

    # Reference result via the Relay VM (no meta-schedule involvement).
    ref = (
        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
        .evaluate()(*[data_np, weight_np, bias_np])
        .numpy()
    )

    params = {"weight": weight_np, "bias": bias_np}

    extracted_tasks = extract_task_from_relay(relay_mod, target, params)

    # Only the dense task is tuned/scheduled manually; everything else is
    # left to the TE fallback.
    tune_tasks = list(
        filter(
            lambda task: "dense" in task.task_name,
            extracted_tasks,
        )
    )

    with tempfile.TemporaryDirectory() as work_dir:
        if do_tune:
            config = ReplayTraceConfig(
                num_trials_per_iter=64,
                num_trials_total=64,
            )
            database = tune_extracted_tasks(
                tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: []
            )
        else:
            # No tuning: commit the manual schedule into a fresh database.
            database = JSONDatabase(
                path_workload=osp.join(work_dir, "database_workload.json"),
                path_tuning_record=osp.join(work_dir, "database_tuning_record.json"),
            )

            for task in tune_tasks:
                mod = Parse._mod(task.dispatched[0])
                workload = database.commit_workload(mod)

                sch = tvm.tir.Schedule(mod)
                block = sch.get_block("compute")
                schedule_rule = sch.get(block).annotations["schedule_rule"]

                if "dense_vnni" in schedule_rule:
                    schedule_dense(block, M, False, sch)

                # Run time of 0.0 is fine: with one record per workload,
                # ApplyHistoryBest always picks it.
                tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])

                database.commit_tuning_record(tune_rec)

        with ApplyHistoryBest(database):
            with tvm.transform.PassContext(
                opt_level=3,
                config={"relay.backend.use_meta_schedule": True},
            ):
                """
                The log should say
                meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
                meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
                meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
                meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul

                This means batch matmul and others are scheduled by TE, and dense (the one not warned) is found in the
                meta schedule tuning database during ApplyHistoryBest
                """
                lib = relay.build(relay_mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    runtime.set_input("data", data_np)
    runtime.run()

    out = runtime.get_output(0).numpy()

    # Integer workload: results must match the VM reference exactly.
    np.testing.assert_equal(out, ref)
+
+
@pytest.mark.skip("Requires cascadelake")
def test_tune_relay_manual_tir_vnni():
    """Exercise the manual VNNI TIR schedule, then register it as a
    meta-schedule custom rule (``meta_schedule.dense_vnni``)."""
    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_intrin)

    manual_tir_common(do_tune=False)

    def schedule_rule_dense_vnni(sch, block):
        # M is None here: the tuned path samples the y-tile instead.
        schedule_dense(block, None, True, sch)
        return [sch]

    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
Review comment:
I want to freely add `schedule_rule` annotations to various TE compute
to experiment with things, so requiring that all `schedule_rule` annotations
have the corresponding packed func registered sounds like a heavyweight
requirement to me. If we had a TOPI equivalent for TIR manual schedules, such
a requirement would be easy to satisfy, but until then I expect that manual TIR
scheduling will be done in a one-off fashion, as in this PR.
Also, it is totally reasonable to want to auto-schedule TE compute annotated
with `schedule_rule`. Currently I annotated TE x86 `dense` and `batch_matmul`
compute with VNNI-specific schedule rules (like `meta_schedule.dense_vnni`
above) to apply my manual TIR schedule, but that prevents any automatic
scheduling from happening on these TE compute. In the future when
auto-tensorization is ready, I want to freely switch between manual and
automatic scheduling.
So I want "the need to annotate `schedule_rule`" and "whether or not I want
to register my custom schedule rule" to be decoupled.
"Silent failure" is certainly something we need to be mindful of. When we
encounter a block with a `schedule_rule` annotation and the schedule-rule
registration is missing, how about emitting a warning to make sure the user
is aware of it?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]