masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r835681478



##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +326,227 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@register
+def int32x16(imm, span):
+    return imm.astype("int32x16", span)
+
+
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    sch.parallel(fused)
+
+    dec = sch.decompose_reduction(block, a_ko)
+
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(tune_tasks, target, config, 
work_dir=work_dir)
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)
+
+    manual_tir_common(do_tune=False)
+
+    def schedule_rule_dense_vnni(sch, block):
+        schedule_dense(block, None, True, sch)
+        return [sch]
+
+    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
+
+    # TODO(masahi): Weird error from tuning with CheckSubtreeCompactDataflow 
in for_kind.cc turned on
+    # manual_tir_common(do_tune=True)

Review comment:
       This is solved. The default set of postprocessors for CPU applies an 
additional `parallel` on top of my manual schedule (see 
https://github.com/apache/tvm/blob/7ff5c83f2191cdf5c4a9c5dbc752789c75f5dfa8/python/tvm/meta_schedule/tune.py#L121).
 Since I'm using a manual schedule, I don't want any postprocessors. Disabling 
them fixed this issue, thanks @vinx13 @junrushao1994 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to