adstraw commented on code in PR #12954:
URL: https://github.com/apache/tvm/pull/12954#discussion_r986077448


##########
tests/python/contrib/test_hexagon/test_software_pipeline_async.py:
##########
@@ -46,18 +47,33 @@ def plus_one_ref(a):
     return plus_one_primfunc, plus_one_ref
 
 
[email protected]_hexagon
-def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inner, dtype, scope):
[email protected]
+def schedule(compute, sched, scope):
     sch = tir.Schedule(compute[0])
-    root = sch.get_block("root")
+
     compute_block = sch.get_block("compute")
     cache_read_block = sch.cache_read(compute_block, 0, scope)
 
     i, _ = sch.get_loops(compute_block)
     sch.compute_at(cache_read_block, i)
-    sch.annotate(i, "software_pipeline_stage", [0, 1])
-    sch.annotate(i, "software_pipeline_order", [0, 1])
-    sch.annotate(i, "software_pipeline_async_stages", [0])
+
+    if sched == "cache_read":

Review Comment:
   Good idea.  I can add this test.



##########
tests/python/contrib/test_hexagon/test_software_pipeline_async.py:
##########
@@ -46,18 +47,33 @@ def plus_one_ref(a):
     return plus_one_primfunc, plus_one_ref
 
 
[email protected]_hexagon
-def test_software_pipeline_with_cache_read(hexagon_launcher, compute, outer, inner, dtype, scope):
[email protected]
+def schedule(compute, sched, scope):
     sch = tir.Schedule(compute[0])
-    root = sch.get_block("root")
+
     compute_block = sch.get_block("compute")
     cache_read_block = sch.cache_read(compute_block, 0, scope)
 
     i, _ = sch.get_loops(compute_block)
     sch.compute_at(cache_read_block, i)
-    sch.annotate(i, "software_pipeline_stage", [0, 1])
-    sch.annotate(i, "software_pipeline_order", [0, 1])
-    sch.annotate(i, "software_pipeline_async_stages", [0])
+
+    if sched == "cache_read":

Review Comment:
   Good idea.  I can add this test.  Can either do it here or in a future PR.



##########
tests/python/contrib/test_hexagon/test_software_pipeline_async.py:
##########
@@ -26,8 +26,9 @@
 
 outer = tvm.testing.parameter(8, 16)
 inner = tvm.testing.parameter(64, 128)
-scope = tvm.testing.parameter("global", "global.vtcm")
 dtype = tvm.testing.parameter("uint8", "float16")
+scope = tvm.testing.parameter("global", "global.vtcm")
+sched = tvm.testing.parameter("cache_read", "cache_read_write")
 
 
 @tvm.testing.fixture

Review Comment:
   Correct.  This allows any number of `cache_read` and `cache_write` stages to be lowered to Async DMA.  Note that there is a known issue when doing `cache_read` for an op with multiple inputs in the same stage; it will be addressed in a future PR.  That PR will change the compute in this test to `a + b` instead of `a + 1` and add support for using Async DMA to `cache_read` both `a` and `b`.
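   For readers unfamiliar with these annotations, here is a minimal pure-Python sketch (no TVM required, and not TVM's actual lowering) of the loop shape that `software_pipeline_stage = [0, 1]`, `software_pipeline_order = [0, 1]`, and `software_pipeline_async_stages = [0]` request in the hunks above: the stage-0 cache_read copy for iteration `i + 1` is overlapped with the stage-1 compute for iteration `i`, using a double buffer in place of the cache:

```python
# Sketch only: models the schedule shape of the annotated loop, where the
# stage-0 copy (cache_read, issued as an async DMA on Hexagon) for
# iteration i+1 overlaps the stage-1 compute for iteration i.

def pipelined_plus_one(a):
    n = len(a)
    cache = [None, None]  # double buffer standing in for the cache_read stage
    out = [None] * n

    cache[0] = a[0]  # prologue: stage-0 copy for iteration 0
    for i in range(n):
        if i + 1 < n:
            # steady state: stage-0 copy for iteration i+1 ...
            cache[(i + 1) % 2] = a[i + 1]
        # ... overlaps stage-1 compute for iteration i
        out[i] = cache[i % 2] + 1
    return out
```

   In TVM's actual lowering the copy is issued asynchronously and synchronized before use; this sketch only shows why the double buffer and the stage/order split are needed.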



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
