nverke commented on code in PR #13352:
URL: https://github.com/apache/tvm/pull/13352#discussion_r1023285390


##########
tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py:
##########
@@ -359,11 +371,223 @@ def test_packed_8x8x32_resnet50(hexagon_launcher):
             params=params,
         )
 
-    with hexagon_launcher.start_session() as session:
+    with hexagon_launcher.create_session() as session:
+        graph_mod = session.get_executor_from_factory(hexagon_lowered)
+        graph_mod.set_input(input_name, inp.copy())
+        graph_mod.run()
+        hexagon_output = graph_mod.get_output(0).numpy()
+
+        llvm_graph_mod = 
tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+        llvm_graph_mod.set_input(input_name, inp.copy())
+        llvm_graph_mod.run()
+        ref_result = llvm_graph_mod.get_output(0).numpy()
+
+
def _schedule_async_dma_conv2d():
    """Manually schedule a conv2d block, created from TE compute op via CreatePrimFunc,
    using 8x8x32 packed layout.

    Returns:
        A callable ``schedule_fn(sch, conv2d_block=None) -> bool`` that applies
        the manual schedule in place and returns True on success, False when no
        "conv2d_NCHWc_int8" block can be found in the schedule.
    """

    def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
        if conv2d_block is None:
            # Fix: was a bare `except:`, which also swallows SystemExit and
            # KeyboardInterrupt. Catch Exception so only real lookup failures
            # (block not present in this PrimFunc) cause the early-out.
            try:
                conv2d_block = sch.get_block("conv2d_NCHWc_int8")
            except Exception:
                return False

        assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]

        # Apply scheduling

        post_blocks = sch.get_consumers(conv2d_block)
        if len(post_blocks) > 0:
            # Fuse all intermediate post ops into the last op.
            # This is equivalent to the traverse_inline function used in TE schedules.
            while True:
                next_post_blocks = []
                for post_block in post_blocks:
                    next_consumers = sch.get_consumers(post_block)
                    if len(next_consumers) > 0:
                        sch.compute_inline(post_block)
                    next_post_blocks += next_consumers
                if len(next_post_blocks) == 0:
                    assert len(post_blocks) == 1
                    outer_block = post_blocks[0]
                    break
                post_blocks = next_post_blocks
        else:
            outer_block = conv2d_block

        # Move the conv2d mma into the injective post mma compute block
        if outer_block != conv2d_block:
            loops = sch.get_loops(outer_block)
            # Compute at the second loop for pipelining.
            sch.compute_at(conv2d_block, loops[1])

        # Add cache for input and output for copying data to vtcm.
        input_a_cache = sch.cache_read(conv2d_block, 0, "global.vtcm")
        sch.compute_at(input_a_cache, sch.get_loops(conv2d_block)[1])
        sch.fuse(*sch.get_loops(input_a_cache)[2:])

        input_b_cache = sch.cache_read(conv2d_block, 1, "global.vtcm")
        sch.compute_at(input_b_cache, sch.get_loops(conv2d_block)[1])
        sch.fuse(*sch.get_loops(input_b_cache)[2:])

        output_cache_write = sch.cache_write(conv2d_block, 0, "global.vtcm")
        sch.fuse(*sch.get_loops(output_cache_write)[2:])

        conv2d_loops = sch.get_loops(block=conv2d_block)
        if len(conv2d_loops) == 8:
            # Handle case where kernel is not 1x1
            oc, x0, x1, ic = conv2d_loops[-4:]
            ic_o, ic_i = sch.split(loop=ic, factors=[None, 4], preserve_unit_iters=True)
            oc_o, oc_i = sch.split(loop=oc, factors=[None, 32], preserve_unit_iters=True)
            sch.reorder(oc_o, x0, x1, ic_o, oc_i, ic_i)
            new_loops = sch.get_loops(block=conv2d_block)
            sch.parallel(new_loops[2])
            sch.unroll(new_loops[-4])
            # TODO(nverke): Add compute optimizations here.
        else:
            # Handle case where kernel is 1x1
            oc, kh, kw, x0, x1, ic = conv2d_loops[-6:]
            ic_o, ic_i = sch.split(loop=ic, factors=[None, 4], preserve_unit_iters=True)
            oc_o, oc_i = sch.split(loop=oc, factors=[None, 32], preserve_unit_iters=True)
            sch.reorder(oc_o, kh, kw, x0, x1, ic_o, oc_i, ic_i)
            new_loops = sch.get_loops(block=conv2d_block)
            sch.parallel(new_loops[2])
            sch.unroll(new_loops[-4])
            # TODO(nverke): Add compute optimizations here.
        sch.blockize(loop=oc_i)

        sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)

        # Pipeline the outer loop: stages/order describe the software pipeline,
        # and stages 0 and 2 are issued as async DMA copies.
        pipeline_loop = conv2d_loops[1]
        sch.annotate(pipeline_loop, "software_pipeline_stage", [0, 0, 1, 2, 3])
        sch.annotate(pipeline_loop, "software_pipeline_order", [0, 1, 2, 3, 4])
        sch.annotate(pipeline_loop, "software_pipeline_async_stages", [0, 2])

        return True

    return schedule_fn
+
+
def tune_async_dma_template(mod, params, hexagon_launcher):
    """Generate async dma template.

    Registers the manual conv2d schedule as the meta-schedule rule, constrains
    the search space to that single schedule, tunes against TARGET_HEXAGON, and
    compiles the tuned module with async-copy lowering enabled.
    """

    def schedule_rule_conv2d_async_dma(sch: Schedule, conv2d_block: BlockRV):
        _schedule_async_dma_conv2d()(sch, conv2d_block)
        return [sch]

    register_func(
        "meta_schedule.conv2d_NCHWc_int8.async_dma.hexagon", schedule_rule_conv2d_async_dma
    )

    def schedule_conv2d_for_tune(sch: Schedule):
        _schedule_async_dma_conv2d()(sch)

    # This line is necessary for link-params to take effect during
    # task extraction and relay.build(...).
    mod = mod.with_attr("executor", EXECUTOR)

    # Fix: the builder previously used {"tir.use_async_copy": 1,
    # "tir.merge_async_commit_queue_scope": 0} while compile_relay used
    # {..: 1, ..: False} — the same two options spelled differently in two
    # places. Define the config once, with booleans, and share it so the
    # tuning build and the final build cannot drift apart.
    pass_config = {
        "tir.use_async_copy": True,
        "tir.merge_async_commit_queue_scope": False,
    }

    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.relay_integration.tune_relay(
            mod=mod,
            target=TARGET_HEXAGON,
            params=params,
            work_dir=work_dir,
            max_trials_global=20000,
            max_trials_per_task=1,
            num_trials_per_iter=1,
            strategy="replay-trace",
            builder=get_hexagon_local_builder(
                tvm.transform.PassContext(opt_level=3, config=pass_config)
            ),
            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
            # Constrain search space to only be the single
            # schedule provided for all blocks. No auto
            # scheduling will be possible.
            space=ms.space_generator.ScheduleFn(
                schedule_conv2d_for_tune,
                sch_rules=[],
                postprocs=[],
                mutator_probs={},
            ),
            # Without this, the same workloads with different constant weights
            # are treated as distinct tuning tasks.
            module_equality="ignore-ndarray",
        )
        return ms.relay_integration.compile_relay(
            database=database,
            mod=mod,
            target=TARGET_HEXAGON,
            params=params,
            pass_config=pass_config,
        )
+
+
[email protected]_hexagon
+def test_async_dma_resnet50(hexagon_launcher):

Review Comment:
   Was able to unify a good amount of the code ✅



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to