This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa8a9f7aaa [TIR][USMP] Preserve DeclBuffer in PoolAllocationToOffsetConverter (#15044)
fa8a9f7aaa is described below

commit fa8a9f7aaa752cfb71665d52d6c03de1050d81cb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:17:03 2023 -0400

    [TIR][USMP] Preserve DeclBuffer in PoolAllocationToOffsetConverter (#15044)
    
    Previously, `PoolAllocationToOffsetConverter` did not remap buffer
    objects occurring in `DeclBuffer` nodes.  This commit updates
    `PoolAllocationToOffsetConverter` to handle `DeclBuffer` nodes. This
    is a subset of changes, being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
---
 .../convert_pool_allocations_to_offsets.cc         |  11 +
 ...ransform_convert_pool_allocations_to_offsets.py | 255 +++++++++++++--------
 2 files changed, 173 insertions(+), 93 deletions(-)

diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index 439e264338..45d060567c 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -99,6 +99,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
   PrimExpr VisitExpr_(const VarNode* op) override;
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
+  Stmt VisitStmt_(const DeclBufferNode* op) override;
 
   Stmt VisitStmt_(const AllocateConstNode* op) override;
   LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body);
@@ -386,6 +387,16 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) {
   return std::move(store);
 }
 
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const DeclBufferNode* op) {
+  auto decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+  Buffer remapped = GetRemappedBuffer(decl->buffer);
+  if (!op->buffer.same_as(remapped)) {
+    decl.CopyOnWrite()->buffer = remapped;
+  }
+  return std::move(decl);
+}
+
 PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
   BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
 
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index d0403fcae9..03929c5436 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -19,7 +19,7 @@ import sys
 import pytest
 import tvm
 from tvm import PoolInfoProperties, WorkspacePoolInfo
-from tvm.script import tir as T
+from tvm.script import tir as T, ir as I
 from tvm.target import Target
 from tvm.tir import stmt_functor
 from tvm.tir.usmp import utils as usmp_utils
@@ -67,6 +67,38 @@ def _assign_targets_to_primfuncs_irmodule(mod, target):
     return ret
 
 
+def _plan_and_convert(tir_mod, pools=None):
+    target = Target("c")
+
+    if pools is None:
+        pools = [
+            WorkspacePoolInfo(
+                "global_workspace",
+                [target],
+            )
+        ]
+
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, pools)
+    main_func = tir_mod["__tvm_main__"]
+    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod)
+    buffer_info_map = buffer_analysis.buffer_info_stmts
+
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    fusmp_algo_greedy_by_size = 
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    buffer_pool_allocations = fusmp_algo_greedy_by_size(
+        buffer_info_arr, buffer_analysis.memory_pressure
+    )
+    fassign_stmt_pool_allocations = 
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, 
buffer_pool_allocations)
+    tir_mod_with_offsets = 
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+        pool_allocations, emit_tvmscript_printable=True
+    )(tir_mod)
+
+    return tir_mod_with_offsets
+
+
 # fmt: off
 @tvm.script.ir_module
 class LinearStructure:
@@ -210,42 +242,24 @@ class LinearStructurePlanned:
 
 
 def test_mobilenet_subgraph():
-    target = Target("c")
-    fast_memory_pool = WorkspacePoolInfo(
-        "fast_memory",
-        [target],
-        PoolInfoProperties(size_hint_bytes=200704),
-    )
-    slow_memory_pool = WorkspacePoolInfo(
-        "slow_memory",
-        [target],
-    )
-    tir_mod = LinearStructure
-    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
-    tir_mod = assign_poolinfos_to_allocates_in_irmodule(
-        tir_mod, [fast_memory_pool, slow_memory_pool]
-    )
-    main_func = tir_mod["__tvm_main__"]
-    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod)
-    buffer_info_map = buffer_analysis.buffer_info_stmts
-
-    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
-    buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = 
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
-    buffer_pool_allocations = fusmp_algo_greedy_by_size(
-        buffer_info_arr, buffer_analysis.memory_pressure
-    )
-    fassign_stmt_pool_allocations = 
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
-    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, 
buffer_pool_allocations)
-    tir_mod_with_offsets = 
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
-        pool_allocations, emit_tvmscript_printable=True
-    )(tir_mod)
+    before = LinearStructure
 
-    tir_mod_with_offsets_ref = LinearStructurePlanned
+    expected = LinearStructurePlanned
 
-    for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
-        actual_func = tir_mod_with_offsets[gv.name_hint]
-        tvm.ir.assert_structural_equal(actual_func, ref_func)
+    target = Target("c")
+    pools = [
+        WorkspacePoolInfo(
+            "fast_memory",
+            [target],
+            PoolInfoProperties(size_hint_bytes=200704),
+        ),
+        WorkspacePoolInfo(
+            "slow_memory",
+            [target],
+        ),
+    ]
+    after = _plan_and_convert(before, pools=pools)
+    tvm.ir.assert_structural_equal(after, expected)
 
 
 # fmt: off
@@ -500,35 +514,10 @@ class ResnetStructurePlanned:
 
 
 def test_resnet_subgraph():
-    target = Target("c")
-    global_workspace_pool = WorkspacePoolInfo(
-        "global_workspace",
-        [target],
-    )
-    tir_mod = ResnetStructure
-    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
-    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, 
[global_workspace_pool])
-    main_func = tir_mod["__tvm_main__"]
-    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod)
-    buffer_info_map = buffer_analysis.buffer_info_stmts
-
-    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
-    buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = 
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
-    buffer_pool_allocations = fusmp_algo_greedy_by_size(
-        buffer_info_arr, buffer_analysis.memory_pressure
-    )
-    fassign_stmt_pool_allocations = 
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
-    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, 
buffer_pool_allocations)
-    tir_mod_with_offsets = 
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
-        pool_allocations, emit_tvmscript_printable=True
-    )(tir_mod)
-
-    tir_mod_with_offsets_ref = ResnetStructurePlanned
-
-    for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
-        actual_func = tir_mod_with_offsets[gv.name_hint]
-        tvm.ir.assert_structural_equal(actual_func, ref_func)
+    before = ResnetStructure
+    expected = ResnetStructurePlanned
+    after = _plan_and_convert(before)
+    tvm.ir.assert_structural_equal(after, expected)
 
 
 @tvm.script.ir_module
@@ -591,36 +580,116 @@ class TensorIntrinStructurePlanned:
 
 
 def test_tensor_intrin():
-    target = Target("c")
-    global_workspace_pool = WorkspacePoolInfo(
-        "global_workspace",
-        [target],
-    )
-
-    tir_mod = TensorIntrinStructure
-    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
-    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, 
[global_workspace_pool])
-    main_func = tir_mod["__tvm_main__"]
-    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, 
tir_mod)
-    buffer_info_map = buffer_analysis.buffer_info_stmts
-
-    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
-    buffer_info_arr = fcreate_array_bi(buffer_info_map)
-    fusmp_algo_greedy_by_size = 
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
-    buffer_pool_allocations = fusmp_algo_greedy_by_size(
-        buffer_info_arr, buffer_analysis.memory_pressure
-    )
-    fassign_stmt_pool_allocations = 
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
-    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, 
buffer_pool_allocations)
-    tir_mod_with_offsets = 
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
-        pool_allocations, emit_tvmscript_printable=True
-    )(tir_mod)
-
+    before = TensorIntrinStructure
+    after = _plan_and_convert(before)
     expected = TensorIntrinStructurePlanned
-
-    for gv, ref_func in expected.functions.items():
-        actual_func = tir_mod_with_offsets[gv.name_hint]
-        tvm.ir.assert_structural_equal(actual_func, ref_func)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+class TestMergeAllocations(tvm.testing.CompareBeforeAfter):
+    def transform(self):
+        return _plan_and_convert
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, 
"int8")):
+                B = T.allocate([256], "int8")
+                T.call_extern("subroutine", A.data, B, dtype="int32")
+                C = T.allocate([256], "int8")
+                T.call_extern("subroutine", B, C, dtype="int32")
+                T.call_extern("subroutine", C, D.data, dtype="int32")
+
+            @T.prim_func
+            def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
+                for i in range(256):
+                    B[i] = A[i]
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def __tvm_main__(
+                A: T.Buffer(256, "int8"),
+                D: T.Buffer(256, "int8"),
+                workspace_var: T.handle("uint8"),
+            ):
+                workspace = T.match_buffer(workspace_var, 512, "uint8", 
strides=[1], align=16)
+                B: T.handle("int8") = T.address_of(workspace[256])
+                T.call_extern("subroutine", A.data, B, workspace.data, 
dtype="int32")
+                C: T.handle("int8") = T.address_of(workspace[0])
+                T.call_extern("subroutine", B, C, workspace.data, 
dtype="int32")
+                T.call_extern("subroutine", C, D.data, workspace.data, 
dtype="int32")
+
+            @T.prim_func
+            def subroutine(
+                A: T.Buffer(256, "int8"),
+                B: T.Buffer(256, "int8"),
+                workspace_var: T.handle("uint8"),
+            ):
+                workspace = T.match_buffer(workspace_var, 512, "uint8", 
strides=[1], align=16)
+                for i in range(256):
+                    B[i] = A[i]
+
+        return mod
+
+
+class TestMergeAllocationsWithDeclBuffer(tvm.testing.CompareBeforeAfter):
+    """Like TestMergeAllocations, but using T.decl_buffer"""
+
+    def transform(self):
+        return _plan_and_convert
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256, 
"int8")):
+                B = T.decl_buffer([256], "int8")
+                T.call_extern("subroutine", A.data, B.data, dtype="int32")
+                C = T.decl_buffer([256], "int8")
+                T.call_extern("subroutine", B.data, C.data, dtype="int32")
+                T.call_extern("subroutine", C.data, D.data, dtype="int32")
+
+            @T.prim_func
+            def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
+                for i in range(256):
+                    B[i] = A[i]
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def __tvm_main__(
+                A: T.Buffer(256, "int8"),
+                D: T.Buffer(256, "int8"),
+                workspace_var: T.handle("uint8"),
+            ):
+                workspace = T.match_buffer(workspace_var, 512, "uint8", 
strides=[1], align=16)
+                B_data: T.handle("int8") = T.address_of(workspace[256])
+                B = T.decl_buffer(256, "int8", data=B_data)
+                T.call_extern("subroutine", A.data, B.data, workspace.data, 
dtype="int32")
+                C_data: T.handle("int8") = T.address_of(workspace[0])
+                C = T.decl_buffer(256, "int8", data=C_data)
+                T.call_extern("subroutine", B.data, C.data, workspace.data, 
dtype="int32")
+                T.call_extern("subroutine", C.data, D.data, workspace.data, 
dtype="int32")
+
+            @T.prim_func
+            def subroutine(
+                A: T.Buffer(256, "int8"),
+                B: T.Buffer(256, "int8"),
+                workspace_var: T.handle("uint8"),
+            ):
+                workspace = T.match_buffer(workspace_var, 512, "uint8", 
strides=[1], align=16)
+                for i in range(256):
+                    B[i] = A[i]
+
+        return mod
 
 
 if __name__ == "__main__":

Reply via email to