This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ff41c6156 [TIR][Schedule] Allow buffer name argument to Schedule.set_scope (#14327)
2ff41c6156 is described below

commit 2ff41c6156f5be311711c6235d94555dd8e5ab24
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Mar 17 23:10:20 2023 -0500

    [TIR][Schedule] Allow buffer name argument to Schedule.set_scope (#14327)
    
    Similar to other schedule primitives, allow the `buffer_index` argument
    of `Schedule.set_scope` to be passed as a human-readable buffer name
    (or as a `Buffer` object) instead of a numeric index.  This change is
    solely in the Python API, and does not impact the internal
    representation of scheduling traces.
---
 python/tvm/tir/schedule/schedule.py                  | 8 +++++++-
 tests/python/unittest/test_tir_schedule_set_scope.py | 9 +++++----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 896e2fc48e..d86cd86ea0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2320,7 +2320,9 @@ class Schedule(Object):
         )
 
     @type_checked
-    def set_scope(self, block: Union[BlockRV, str], buffer_index: int, 
storage_scope: str) -> None:
+    def set_scope(
+        self, block: Union[BlockRV, str], buffer_index: Union[int, str, 
Buffer], storage_scope: str
+    ) -> None:
         """Set the storage scope of a buffer, where the buffer is
         specified by the a block and a write-index
 
@@ -2387,6 +2389,10 @@ class Schedule(Object):
         Set_scope requires the buffer to be an intermediate buffer defined via 
`alloc_buffer`.
         """
         block = self._normalize_block_arg(block)
+        if not isinstance(buffer_index, int):
+            _, buffer_index, _ = self._normalize_buffer_arg(
+                block, buffer_index, required_buffer_type="write"
+            )
         _ffi_api.ScheduleSetScope(  # type: ignore # pylint: disable=no-member
             self, block, buffer_index, storage_scope
         )
diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py
index e5fa25fbc3..40df049783 100644
--- a/tests/python/unittest/test_tir_schedule_set_scope.py
+++ b/tests/python/unittest/test_tir_schedule_set_scope.py
@@ -88,20 +88,21 @@ def element_wise_subregion_match_set_scope(A: T.Buffer((128, 128), "float32"), C
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, 
"block_name": True})
+use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, 
"buffer_name": True})
 
-def test_set_scope(use_block_name):
+def test_set_scope(use_block_name, use_buffer_name):
     func = element_wise
     s = tir.Schedule(func, debug_mask='all')
-    s.set_scope('B' if use_block_name else s.get_block("B"), 0, "shared")
+    s.set_scope('B' if use_block_name else s.get_block("B"), 'B' if 
use_buffer_name else 0, "shared")
     tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"])
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def test_set_scope_fail_on_output_buffer(use_block_name):
+def test_set_scope_fail_on_output_buffer(use_block_name, use_buffer_name):
     func = element_wise
     s = tir.Schedule(func, debug_mask='all')
     with pytest.raises(tvm.tir.ScheduleError):
-        s.set_scope('C' if use_block_name else s.get_block("C"), 0, "shared")
+        s.set_scope('C' if use_block_name else s.get_block("C"), 'C' if 
use_buffer_name else 0, "shared")
 
 
 def test_set_scope_fail_on_index_out_of_bound():

Reply via email to