masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r836859198



##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
[email protected]_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
[email protected]_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+# Name under which dot_product_desc / dot_product_intrin are registered with
+# tir.TensorIntrin.register and then looked up by sch.tensorize in schedule_dense.
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    """Tile a dense block and tensorize its inner 16 x 4 loop nest with VNNI_INTRIN.
+
+    Parameters
+    ----------
+    block : BlockRV
+        The dense compute block of ``sch`` (as returned by ``sch.get_block``).
+    M : int or None
+        Row count of the dense output. Only read when ``do_tune`` is False,
+        where the row loop is split with an inner factor of ``min(M, 64)``.
+    do_tune : bool
+        If True, the row-loop split factors are sampled with
+        ``sample_perfect_tile`` instead of being fixed.
+    sch : tvm.tir.Schedule
+        The schedule; it is transformed in place and nothing is returned.
+    """
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        # Walk down the chain of consumers (presumably the fused bias_add / add
+        # epilogue), inlining every consumer that still has consumers of its
+        # own. The terminal consumer is kept; its two innermost loops are tiled
+        # below and the dense block is later computed at that tile.
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                # Exactly one terminal block is expected.
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        # No epilogue: tile the dense block's own row/column loops directly.
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    # Split the row loop: sampled factors when tuning, otherwise a fixed inner
+    # tile of at most 64 rows.
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    # Split the column loop by 16 (the intrinsic's 16 output lanes) and fuse
+    # the two outer loops into one.
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    # With an epilogue: vectorize its inner column loop and move the dense
+    # computation under the row tile.
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    # Re-fetch the dense block's two innermost loops (the last one presumably
+    # being the reduction axis), split the last one by 4 and order the loops as
+    # (ko, xi, ki) so that (xi, ki) matches the 16 x 4 nest of dot_product_desc.
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    # Parallelize the fused outer loop and separate the initialization from
+    # the reduction at ko.
+    sch.parallel(fused)
+    dec = sch.decompose_reduction(block, a_ko)
+
+    # Vectorize the innermost loop of the generated init block.
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    # Replace the (xi, ki) nest with the VNNI intrinsic.
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(
+                tune_tasks, target, config, work_dir=work_dir, 
postprocs=lambda: []
+            )
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
[email protected]("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)
+
+    manual_tir_common(do_tune=False)
+
+    def schedule_rule_dense_vnni(sch, block):
+        schedule_dense(block, None, True, sch)
+        return [sch]
+
+    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)

Review comment:
       It is not "overwriting", in the sense that (1) TE scheduling is not affected and (2) without this registration, meta schedule tuning fails with:
   ```
     File "/home/masa/projects/dev/tvm/src/meta_schedule/space_generator/post_order_apply.cc", line 149
   ValueError: Check failed: (f) is false: Custom schedule rule not found: meta_schedule.dense_vnni
   ```
   
   This is because, for every TE compute annotated with `schedule_rule`, like the one at https://github.com/apache/tvm/blob/ce335c3a74185df6cc1152e53c60695d8a418d8e/python/tvm/topi/x86/dense.py#L299, we currently require the corresponding schedule rule to be registered (which is what this line does).
   
   Thinking about this now, I wonder whether failing is the desired behavior: if the custom schedule rule is not registered, we could instead ignore the `schedule_rule` annotation and fall back to automatic scheduling. cc @junrushao1994




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to