masahi commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r836874713



##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
[email protected]_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
[email protected]_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    sch.parallel(fused)
+    dec = sch.decompose_reduction(block, a_ko)
+
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(
+                tune_tasks, target, config, work_dir=work_dir, 
postprocs=lambda: []
+            )
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
[email protected]("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)

Review comment:
       This test case is not a demonstration of TIR tensorization per se, so 
the mechanics of `tir.TensorIntrin.register` are not too relevant here. We have 
dedicated test cases for TIR tensorization and more uses of 
`tir.TensorIntrin.register` in 
https://github.com/apache/tvm/blob/ff3a48e9c0af1ef5c9395075858755ea5e3c93f5/tests/python/unittest/test_tir_schedule_tensorize.py#L459-L462
  




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to