masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r835661633
##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +326,227 @@ def get_output(data, lib):
assert np.allclose(actual_output, expected_output, rtol=1e-4,
atol=2e-4)
+@register
+def int32x16(imm, span):
+ return imm.astype("int32x16", span)
+
+
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+ B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+ C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+ for i in T.serial(0, 16):
+ with T.init():
+ C[i] = T.int32(0)
+ for k in T.serial(0, 4):
+ with T.block("update"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk],
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+ B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+ C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+
+ A_u8x4 = A.vload([0], "uint8x4")
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+ B_i8x64 = B.vload([0, 0], dtype="int8x64")
+ B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+ C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this
is an update +=
+ T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+ T.uint32(0),
+ T.int32x16(0),
+ T.broadcast(A_i32, 16),
+ B_i32x16,
+ dtype="int32x16",
+ )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+ post_blocks = sch.get_consumers(block)
+
+ if len(post_blocks) > 0:
+ while True:
+ next_post_blocks = []
+ for post_block in post_blocks:
+ next_consumers = sch.get_consumers(post_block)
+
+ if len(next_consumers) > 0:
+ sch.compute_inline(post_block)
+
+ next_post_blocks += next_consumers
+
+ if len(next_post_blocks) == 0:
+ assert len(post_blocks) == 1
+ outer_block = post_blocks[0]
+ a_y, a_x = sch.get_loops(outer_block)[-2:]
+ break
+
+ post_blocks = next_post_blocks
+ else:
+ a_y, a_x, _ = sch.get_loops(block)[-3:]
+ outer_block = block
+
+ if do_tune:
+ y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+ a_yo, a_yi = sch.split(a_y, factors=y_factors)
+ else:
+ a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+ a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+ sch.reorder(a_yo, a_xo, a_yi, a_xi)
+ fused = sch.fuse(a_yo, a_xo)
+
+ if outer_block != block:
+ sch.vectorize(a_xi)
+ sch.compute_at(block, a_yi)
+
+ a_xi, a_k = sch.get_loops(block)[-2:]
+ a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+ sch.reorder(a_ko, a_xi, a_ki)
+
+ sch.parallel(fused)
+
+ dec = sch.decompose_reduction(block, a_ko)
+
+ init_loop = sch.get_loops(dec)[-1]
+ sch.vectorize(init_loop)
+
+ sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+ M, N, K = 1024, 1024, 1024
+ data_shape = (M, K)
+ weight_shape = (N, K)
+
+ data_dtype = "uint8"
+ data = relay.var("data", shape=data_shape, dtype=data_dtype)
+ weight = relay.var("weight", shape=weight_shape, dtype="int8")
+ bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+ # dense is tuned by the TIR schedule above, bmm is scheduled by TE
(topi/x86/batch_matmul.py)
+ dense = relay.nn.dense(data, weight, out_dtype="int32")
+ bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+ out = relay.nn.batch_matmul(
+ relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+ relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+ out_dtype="int32",
+ )
+
+ relay_mod = tvm.IRModule.from_expr(out)
+
+ target = "llvm -mcpu=cascadelake -num-cores 4"
+ dev = tvm.device(target, 0)
+
+ data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+ weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+ bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+ ref = (
+ relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+ .evaluate()(*[data, weight_np, bias_np])
+ .numpy()
+ )
+
+ params = {"weight": weight_np, "bias": bias_np}
+
+ extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+ tune_tasks = list(
+ filter(
+ lambda task: "dense" in task.task_name,
+ extracted_tasks,
+ )
+ )
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ if do_tune:
+ config = ReplayTraceConfig(
+ num_trials_per_iter=64,
+ num_trials_total=64,
+ )
+ database = tune_extracted_tasks(tune_tasks, target, config,
work_dir=work_dir)
+ else:
+ database = JSONDatabase(
+ path_workload=osp.join(work_dir, "database_workload.json"),
+ path_tuning_record=osp.join(work_dir,
"database_tuning_record.json"),
+ )
+
+ for task in tune_tasks:
+ mod = Parse._mod(task.dispatched[0])
+ workload = database.commit_workload(mod)
+
+ sch = tvm.tir.Schedule(mod)
+ block = sch.get_block("compute")
+ schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+ if "dense_vnni" in schedule_rule:
+ schedule_dense(block, M, False, sch)
+
+ tune_rec = TuningRecord(sch.trace, [0.0], workload,
tvm.target.Target(target), [])
+
+ database.commit_tuning_record(tune_rec)
+
+ with ApplyHistoryBest(database):
+ with tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ """
+ The log should say
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_expand_dims
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_cast
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_cast_1
+ meta_schedule/integration.cc:146: Warning: Cannot find workload:
tvmgen_default_fused_nn_batch_matmul
+
+ This means batch matmul and others are scheduled by TE, and dense
(the one not warned) is found in the
+ meta schedule tuning database during ApplyHistoryBest
+ """
+ lib = relay.build(relay_mod, target=target, params=params)
+
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ runtime.set_input("data", data)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+
+ np.testing.assert_equal(out, ref)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+ tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc,
dot_product_intrin)
+
+ manual_tir_common(do_tune=False)
+
+ def schedule_rule_dense_vnni(sch, block):
+ schedule_dense(block, None, True, sch)
+ return [sch]
+
+ register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
+
+ # TODO(masahi): Weird error from tuning with CheckSubtreeCompactDataflow
in for_kind.cc turned on
+ # manual_tir_common(do_tune=True)
Review comment:
Currently, enabling the commented-out `manual_tir_common(do_tune=True)` call results in the following error:
```
[08:11:02]
/home/masa/projects/dev/tvm/src/meta_schedule/task_scheduler/task_scheduler.cc:127:
Scheduler picks Task #0: "fused_nn_contrib_dense_pack_add_add"
Traceback (most recent call last):
File "test_meta_schedule_tune_relay.py", line 559, in <module>
test_tune_relay_manual_tir_vnni()
File "test_meta_schedule_tune_relay.py", line 547, in
test_tune_relay_manual_tir_vnni
manual_tir_common(do_tune=True)
File "test_meta_schedule_tune_relay.py", line 485, in manual_tir_common
database = tune_extracted_tasks(tune_tasks, target, config,
work_dir=work_dir)
File "/home/masa/projects/dev/tvm/python/tvm/meta_schedule/tune.py", line
716, in tune_extracted_tasks
task_scheduler.tune()
File
"/home/masa/projects/dev/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py",
line 60, in tune
_ffi_api.TaskSchedulerTune(self) # type: ignore # pylint:
disable=no-member
File "/home/masa/projects/dev/tvm/python/tvm/_ffi/_ctypes/packed_func.py",
line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
5: TVMFuncCall
4:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void
(tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void, , void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler,
tvm::meta_schedule::TaskSchedulerNode, void, , void>(void
(tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1},
std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>
>)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>
>::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char,
std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
3: tvm::meta_schedule::TaskSchedulerNode::Tune()
2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void
(int, int)> const&) [clone .cold]
File "/home/masa/projects/dev/tvm/src/support/parallel_for.cc", line 128
RuntimeError: parallel_for_dynamic error with ScheduleError: (not rendered)
```
If I remove the `CheckSubtreeCompactDataflow` check at
https://github.com/apache/tvm/blob/0ddaaa6a7d1009ea7ca8313a51eb19abb8ae7699/src/tir/schedule/primitive/for_kind.cc#L160
tuning works.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]