jwfromm commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r836741636



##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    sch.parallel(fused)
+    dec = sch.decompose_reduction(block, a_ko)
+
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(
+                tune_tasks, target, config, work_dir=work_dir, 
postprocs=lambda: []
+            )
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)
+
+    manual_tir_common(do_tune=False)
+
+    def schedule_rule_dense_vnni(sch, block):
+        schedule_dense(block, None, True, sch)
+        return [sch]
+
+    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)

Review comment:
       Is this overwriting the default behavior for dense on cascadelake? If 
so, a quick comment saying so wouldn't hurt.

##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +326,227 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@register
+def int32x16(imm, span):
+    return imm.astype("int32x16", span)
+
+
+@T.prim_func

Review comment:
       Are these generally useful intrinsics or just needed for the tests in 
this file? If they're just test intrinsics, it makes sense to keep them here.

##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    sch.parallel(fused)
+    dec = sch.decompose_reduction(block, a_ko)
+
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(
+                tune_tasks, target, config, work_dir=work_dir, 
postprocs=lambda: []
+            )
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)

Review comment:
       I think this is one of the most important lines in the whole file. I'd 
love to see a quick comment explaining what it does / how it works for future 
readers trying to figure out autotensorization.

##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):

Review comment:
       Adding a quick comment to each of these functions explaining what 
they're for would make this a lot easier to read.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to