masahi commented on code in PR #13180:
URL: https://github.com/apache/tvm/pull/13180#discussion_r1003841498
##########
tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py:
##########
@@ -184,3 +187,177 @@ def test_resnet50(hexagon_launcher):
hexagon_lowered.get_graph_json(), hexagon_lowered.lib
)
print(debug_ex.profile(input_name=inp.copy()))
+
+
def _schedule_packed_8x8x32_conv2d(do_tune: bool):
    """Manually schedule a conv2d block, created from TE compute op via
    CreatePrimFunc, using 8x8x32 packed layout.

    Parameters
    ----------
    do_tune : bool
        Currently unused by the returned schedule function; kept so existing
        callers keep working.
        NOTE(review): presumably reserved for tunable schedule decisions --
        confirm or remove.

    Returns
    -------
    Callable
        ``schedule_fn(sch, conv2d_block=None) -> bool``, which applies the
        schedule to ``sch`` in place.
    """

    def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
        """Apply the packed 8x8x32 conv2d schedule to ``sch`` in place.

        Parameters
        ----------
        sch : Schedule
            The schedule to transform.
        conv2d_block : Optional[BlockRV]
            The conv2d block to schedule. When None, the block named
            "conv2d_NCHWc_int8" is looked up in ``sch``.

        Returns
        -------
        bool
            False if no block was given and "conv2d_NCHWc_int8" could not be
            retrieved from ``sch``; True once the schedule has been applied.
        """
        # Only a TVM-side failure of the lookup should mean "block not
        # present"; a bare ``except:`` would also swallow KeyboardInterrupt,
        # SystemExit and genuine bugs in this function.
        from tvm import TVMError  # pylint: disable=import-outside-toplevel

        if conv2d_block is None:
            try:
                conv2d_block = sch.get_block("conv2d_NCHWc_int8")
            except TVMError:
                return False

        # Use .get() so a missing annotation fails this assertion with a
        # readable message instead of a bare KeyError.
        schedule_rule = sch.get(conv2d_block).annotations.get("schedule_rule", "")
        assert (
            "conv2d_NCHWc_int8" in schedule_rule
        ), f"unexpected schedule_rule annotation: {schedule_rule!r}"

        # Apply scheduling

        post_blocks = sch.get_consumers(conv2d_block)
        if len(post_blocks) > 0:
            # Fuse all intermediate post ops into the last op.
            # This is equivalent to the traverse_inline function used in TE schedules.
            while True:
                next_post_blocks = []
                seen_names = set()
                for post_block in post_blocks:
                    next_consumers = sch.get_consumers(post_block)
                    if len(next_consumers) == 0:
                        continue
                    sch.compute_inline(post_block)
                    for consumer in next_consumers:
                        # Two post ops may feed the same consumer. Queueing it
                        # twice would inline it twice, or break the
                        # single-output check below. Dedup by block name.
                        # NOTE(review): assumes block names are unique in the
                        # PrimFunc and that separate get_consumers calls may
                        # return distinct BlockRV objects for the same block.
                        name = sch.get(consumer).name_hint
                        if name not in seen_names:
                            seen_names.add(name)
                            next_post_blocks.append(consumer)
                if len(next_post_blocks) == 0:
                    break
                post_blocks = next_post_blocks
            assert len(post_blocks) == 1, "expected a single output block after inlining post ops"
            outer_block = post_blocks[0]

            # Move the conv2d mma into the injective post mma compute block
            loops = sch.get_loops(outer_block)
            # TODO(csullivan): Currently does all post conv2d mma steps
            # directly after accumulation for one spatial pixel. May
            # be desirable to do this with coarser spatial granularity
            # NOTE(review): loop index 4 is hard-coded; assumes the outer
            # block has at least five loops.
            sch.compute_at(conv2d_block, loops[4])
        else:
            outer_block = conv2d_block

        def index_map_nchw32c_nchw8h8w32c(n, c, h, w, c32):
            # (n, c, h, w, c32) -> (n, c, h // 8, w // 8, h % 8, w % 8, c32):
            # split H and W into 8x8 tiles, keeping the 32-wide channel last.
            return [n, c, h // 8, w // 8, h % 8, w % 8, c32]

        # Add cache for input and output activation layout transform,
        # note that weight is already in correct layout. Only the effect on
        # ``sch`` is needed, so the returned block handles are not kept.
        sch.cache_read(conv2d_block, 0, "global")
        sch.cache_write(outer_block, 0, "global")
        # Transform the layout of the input
        sch.transform_layout(
            conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
        )
        # Transform the layout of the int32 accumulator
        sch.transform_layout(
            conv2d_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
        )
        # Transform the layout of the output
        sch.transform_layout(
            outer_block, ("write", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
        )
        return True

    return schedule_fn
+
+
+def tune_packed_8x8x32_template(mod, params, hexagon_launcher):
+ def schedule_rule_conv2d_packed_8x8x32(sch: Schedule, conv2d_block:
BlockRV):
+ _schedule_packed_8x8x32_conv2d(do_tune=True)(sch, conv2d_block)
+ return [sch]
+
+ # register_func("meta_schedule.conv2d_NCHWc_int8",
schedule_rule_conv2d_packed_8x8x32)
+
+ def schedule_conv2d_for_tune(sch: Schedule):
+ _schedule_packed_8x8x32_conv2d(do_tune=True)(sch)
+
+ # This line is necessary for link-params to take effect during
+ # task extraction and relay.build(...).
+ mod = mod.with_attr("executor", executor)
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ database = ms.relay_integration.tune_relay(
+ mod=mod,
+ target=target,
+ params=params,
+ work_dir=work_dir,
+ max_trials_global=20000,
+ max_trials_per_task=1,
+ num_trials_per_iter=1,
+ strategy="replay-trace",
+ builder=get_hexagon_local_builder(),
+ runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+ # TODO(csullivan): Configrm the below is accurate
Review Comment:
@csullivan, should we remove the TODO comment? (It also has a typo: "Configrm".)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]