This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2ff41c6156 [TIR][Schedule] Allow buffer name argument to Schedule.set_scope (#14327)
2ff41c6156 is described below
commit 2ff41c6156f5be311711c6235d94555dd8e5ab24
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Mar 17 23:10:20 2023 -0500
[TIR][Schedule] Allow buffer name argument to Schedule.set_scope (#14327)
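As with other schedule primitives, allow the buffer argument of
`Schedule.set_scope` to be given as a human-readable buffer name (or a
`Buffer` object) instead of a numeric write-buffer index; integer
indices are still accepted. This change is solely in the Python API:
non-integer arguments are converted to the corresponding write-buffer
index before the FFI call, so the internal representation of scheduling
traces is unchanged.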
---
python/tvm/tir/schedule/schedule.py | 8 +++++++-
tests/python/unittest/test_tir_schedule_set_scope.py | 9 +++++----
2 files changed, 12 insertions(+), 5 deletions(-)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 896e2fc48e..d86cd86ea0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2320,7 +2320,9 @@ class Schedule(Object):
)
@type_checked
- def set_scope(self, block: Union[BlockRV, str], buffer_index: int, storage_scope: str) -> None:
+ def set_scope(
+ self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer], storage_scope: str
+ ) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index
@@ -2387,6 +2389,10 @@ class Schedule(Object):
Set_scope requires the buffer to be an intermediate buffer defined via
`alloc_buffer`.
"""
block = self._normalize_block_arg(block)
+ if not isinstance(buffer_index, int):
+ _, buffer_index, _ = self._normalize_buffer_arg(
+ block, buffer_index, required_buffer_type="write"
+ )
_ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member
self, block, buffer_index, storage_scope
)
diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py
index e5fa25fbc3..40df049783 100644
--- a/tests/python/unittest/test_tir_schedule_set_scope.py
+++ b/tests/python/unittest/test_tir_schedule_set_scope.py
@@ -88,20 +88,21 @@ def element_wise_subregion_match_set_scope(A: T.Buffer((128, 128), "float32"), C
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
+use_buffer_name = tvm.testing.parameter(by_dict={"buffer_index": False, "buffer_name": True})
-def test_set_scope(use_block_name):
+def test_set_scope(use_block_name, use_buffer_name):
func = element_wise
s = tir.Schedule(func, debug_mask='all')
- s.set_scope('B' if use_block_name else s.get_block("B"), 0, "shared")
+ s.set_scope('B' if use_block_name else s.get_block("B"), 'B' if use_buffer_name else 0, "shared")
tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
-def test_set_scope_fail_on_output_buffer(use_block_name):
+def test_set_scope_fail_on_output_buffer(use_block_name, use_buffer_name):
func = element_wise
s = tir.Schedule(func, debug_mask='all')
with pytest.raises(tvm.tir.ScheduleError):
- s.set_scope('C' if use_block_name else s.get_block("C"), 0, "shared")
+ s.set_scope('C' if use_block_name else s.get_block("C"), 'C' if use_buffer_name else 0, "shared")
def test_set_scope_fail_on_index_out_of_bound():