This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fa8a9f7aaa [TIR][USMP] Preserve DeclBuffer in
PoolAllocationToOffsetConverter (#15044)
fa8a9f7aaa is described below
commit fa8a9f7aaa752cfb71665d52d6c03de1050d81cb
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Jun 16 18:17:03 2023 -0400
[TIR][USMP] Preserve DeclBuffer in PoolAllocationToOffsetConverter (#15044)
Previously, `PoolAllocationToOffsetConverter` did not remap buffer
objects occurring in `DeclBuffer` nodes. This commit adds a
`DeclBuffer` visitor to `PoolAllocationToOffsetConverter` so that
those buffer objects are remapped as well. This change is one of
several independent portions split out from
https://github.com/apache/tvm/pull/14778.
---
.../convert_pool_allocations_to_offsets.cc | 11 +
...ransform_convert_pool_allocations_to_offsets.py | 255 +++++++++++++--------
2 files changed, 173 insertions(+), 93 deletions(-)
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index 439e264338..45d060567c 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -99,6 +99,7 @@ class PoolAllocationToOffsetConverter : public
StmtExprMutator {
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
+ Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var&
buffer_var, const Stmt& body);
@@ -386,6 +387,16 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const
BufferStoreNode* op) {
return std::move(store);
}
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const DeclBufferNode* op) {
+ auto decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+ Buffer remapped = GetRemappedBuffer(decl->buffer);
+ if (!op->buffer.same_as(remapped)) {
+ decl.CopyOnWrite()->buffer = remapped;
+ }
+ return std::move(decl);
+}
+
PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op)
{
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
diff --git
a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index d0403fcae9..03929c5436 100644
---
a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -19,7 +19,7 @@ import sys
import pytest
import tvm
from tvm import PoolInfoProperties, WorkspacePoolInfo
-from tvm.script import tir as T
+from tvm.script import tir as T, ir as I
from tvm.target import Target
from tvm.tir import stmt_functor
from tvm.tir.usmp import utils as usmp_utils
@@ -67,6 +67,38 @@ def _assign_targets_to_primfuncs_irmodule(mod, target):
return ret
+def _plan_and_convert(tir_mod, pools=None):
+ target = Target("c")
+
+ if pools is None:
+ pools = [
+ WorkspacePoolInfo(
+ "global_workspace",
+ [target],
+ )
+ ]
+
+ tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+ tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, pools)
+ main_func = tir_mod["__tvm_main__"]
+ buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func,
tir_mod)
+ buffer_info_map = buffer_analysis.buffer_info_stmts
+
+ fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+ buffer_info_arr = fcreate_array_bi(buffer_info_map)
+ fusmp_algo_greedy_by_size =
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+ buffer_pool_allocations = fusmp_algo_greedy_by_size(
+ buffer_info_arr, buffer_analysis.memory_pressure
+ )
+ fassign_stmt_pool_allocations =
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+ pool_allocations = fassign_stmt_pool_allocations(buffer_info_map,
buffer_pool_allocations)
+ tir_mod_with_offsets =
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+ pool_allocations, emit_tvmscript_printable=True
+ )(tir_mod)
+
+ return tir_mod_with_offsets
+
+
# fmt: off
@tvm.script.ir_module
class LinearStructure:
@@ -210,42 +242,24 @@ class LinearStructurePlanned:
def test_mobilenet_subgraph():
- target = Target("c")
- fast_memory_pool = WorkspacePoolInfo(
- "fast_memory",
- [target],
- PoolInfoProperties(size_hint_bytes=200704),
- )
- slow_memory_pool = WorkspacePoolInfo(
- "slow_memory",
- [target],
- )
- tir_mod = LinearStructure
- tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
- tir_mod = assign_poolinfos_to_allocates_in_irmodule(
- tir_mod, [fast_memory_pool, slow_memory_pool]
- )
- main_func = tir_mod["__tvm_main__"]
- buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func,
tir_mod)
- buffer_info_map = buffer_analysis.buffer_info_stmts
-
- fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
- buffer_info_arr = fcreate_array_bi(buffer_info_map)
- fusmp_algo_greedy_by_size =
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
- buffer_pool_allocations = fusmp_algo_greedy_by_size(
- buffer_info_arr, buffer_analysis.memory_pressure
- )
- fassign_stmt_pool_allocations =
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
- pool_allocations = fassign_stmt_pool_allocations(buffer_info_map,
buffer_pool_allocations)
- tir_mod_with_offsets =
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
- pool_allocations, emit_tvmscript_printable=True
- )(tir_mod)
+ before = LinearStructure
- tir_mod_with_offsets_ref = LinearStructurePlanned
+ expected = LinearStructurePlanned
- for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
- actual_func = tir_mod_with_offsets[gv.name_hint]
- tvm.ir.assert_structural_equal(actual_func, ref_func)
+ target = Target("c")
+ pools = [
+ WorkspacePoolInfo(
+ "fast_memory",
+ [target],
+ PoolInfoProperties(size_hint_bytes=200704),
+ ),
+ WorkspacePoolInfo(
+ "slow_memory",
+ [target],
+ ),
+ ]
+ after = _plan_and_convert(before, pools=pools)
+ tvm.ir.assert_structural_equal(after, expected)
# fmt: off
@@ -500,35 +514,10 @@ class ResnetStructurePlanned:
def test_resnet_subgraph():
- target = Target("c")
- global_workspace_pool = WorkspacePoolInfo(
- "global_workspace",
- [target],
- )
- tir_mod = ResnetStructure
- tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
- tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod,
[global_workspace_pool])
- main_func = tir_mod["__tvm_main__"]
- buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func,
tir_mod)
- buffer_info_map = buffer_analysis.buffer_info_stmts
-
- fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
- buffer_info_arr = fcreate_array_bi(buffer_info_map)
- fusmp_algo_greedy_by_size =
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
- buffer_pool_allocations = fusmp_algo_greedy_by_size(
- buffer_info_arr, buffer_analysis.memory_pressure
- )
- fassign_stmt_pool_allocations =
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
- pool_allocations = fassign_stmt_pool_allocations(buffer_info_map,
buffer_pool_allocations)
- tir_mod_with_offsets =
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
- pool_allocations, emit_tvmscript_printable=True
- )(tir_mod)
-
- tir_mod_with_offsets_ref = ResnetStructurePlanned
-
- for gv, ref_func in tir_mod_with_offsets_ref.functions.items():
- actual_func = tir_mod_with_offsets[gv.name_hint]
- tvm.ir.assert_structural_equal(actual_func, ref_func)
+ before = ResnetStructure
+ expected = ResnetStructurePlanned
+ after = _plan_and_convert(before)
+ tvm.ir.assert_structural_equal(after, expected)
@tvm.script.ir_module
@@ -591,36 +580,116 @@ class TensorIntrinStructurePlanned:
def test_tensor_intrin():
- target = Target("c")
- global_workspace_pool = WorkspacePoolInfo(
- "global_workspace",
- [target],
- )
-
- tir_mod = TensorIntrinStructure
- tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
- tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod,
[global_workspace_pool])
- main_func = tir_mod["__tvm_main__"]
- buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func,
tir_mod)
- buffer_info_map = buffer_analysis.buffer_info_stmts
-
- fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
- buffer_info_arr = fcreate_array_bi(buffer_info_map)
- fusmp_algo_greedy_by_size =
tvm.get_global_func("tir.usmp.algo.greedy_by_size")
- buffer_pool_allocations = fusmp_algo_greedy_by_size(
- buffer_info_arr, buffer_analysis.memory_pressure
- )
- fassign_stmt_pool_allocations =
tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
- pool_allocations = fassign_stmt_pool_allocations(buffer_info_map,
buffer_pool_allocations)
- tir_mod_with_offsets =
tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
- pool_allocations, emit_tvmscript_printable=True
- )(tir_mod)
-
+ before = TensorIntrinStructure
+ after = _plan_and_convert(before)
expected = TensorIntrinStructurePlanned
-
- for gv, ref_func in expected.functions.items():
- actual_func = tir_mod_with_offsets[gv.name_hint]
- tvm.ir.assert_structural_equal(actual_func, ref_func)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+class TestMergeAllocations(tvm.testing.CompareBeforeAfter):
+ def transform(self):
+ return _plan_and_convert
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256,
"int8")):
+ B = T.allocate([256], "int8")
+ T.call_extern("subroutine", A.data, B, dtype="int32")
+ C = T.allocate([256], "int8")
+ T.call_extern("subroutine", B, C, dtype="int32")
+ T.call_extern("subroutine", C, D.data, dtype="int32")
+
+ @T.prim_func
+ def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
+ for i in range(256):
+ B[i] = A[i]
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def __tvm_main__(
+ A: T.Buffer(256, "int8"),
+ D: T.Buffer(256, "int8"),
+ workspace_var: T.handle("uint8"),
+ ):
+ workspace = T.match_buffer(workspace_var, 512, "uint8",
strides=[1], align=16)
+ B: T.handle("int8") = T.address_of(workspace[256])
+ T.call_extern("subroutine", A.data, B, workspace.data,
dtype="int32")
+ C: T.handle("int8") = T.address_of(workspace[0])
+ T.call_extern("subroutine", B, C, workspace.data,
dtype="int32")
+ T.call_extern("subroutine", C, D.data, workspace.data,
dtype="int32")
+
+ @T.prim_func
+ def subroutine(
+ A: T.Buffer(256, "int8"),
+ B: T.Buffer(256, "int8"),
+ workspace_var: T.handle("uint8"),
+ ):
+ workspace = T.match_buffer(workspace_var, 512, "uint8",
strides=[1], align=16)
+ for i in range(256):
+ B[i] = A[i]
+
+ return mod
+
+
+class TestMergeAllocationsWithDeclBuffer(tvm.testing.CompareBeforeAfter):
+ """Like TestMergeAllocations, but using T.decl_buffer"""
+
+ def transform(self):
+ return _plan_and_convert
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def __tvm_main__(A: T.Buffer(256, "int8"), D: T.Buffer(256,
"int8")):
+ B = T.decl_buffer([256], "int8")
+ T.call_extern("subroutine", A.data, B.data, dtype="int32")
+ C = T.decl_buffer([256], "int8")
+ T.call_extern("subroutine", B.data, C.data, dtype="int32")
+ T.call_extern("subroutine", C.data, D.data, dtype="int32")
+
+ @T.prim_func
+ def subroutine(A: T.Buffer(256, "int8"), B: T.Buffer(256, "int8")):
+ for i in range(256):
+ B[i] = A[i]
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def __tvm_main__(
+ A: T.Buffer(256, "int8"),
+ D: T.Buffer(256, "int8"),
+ workspace_var: T.handle("uint8"),
+ ):
+ workspace = T.match_buffer(workspace_var, 512, "uint8",
strides=[1], align=16)
+ B_data: T.handle("int8") = T.address_of(workspace[256])
+ B = T.decl_buffer(256, "int8", data=B_data)
+ T.call_extern("subroutine", A.data, B.data, workspace.data,
dtype="int32")
+ C_data: T.handle("int8") = T.address_of(workspace[0])
+ C = T.decl_buffer(256, "int8", data=C_data)
+ T.call_extern("subroutine", B.data, C.data, workspace.data,
dtype="int32")
+ T.call_extern("subroutine", C.data, D.data, workspace.data,
dtype="int32")
+
+ @T.prim_func
+ def subroutine(
+ A: T.Buffer(256, "int8"),
+ B: T.Buffer(256, "int8"),
+ workspace_var: T.handle("uint8"),
+ ):
+ workspace = T.match_buffer(workspace_var, 512, "uint8",
strides=[1], align=16)
+ for i in range(256):
+ B[i] = A[i]
+
+ return mod
if __name__ == "__main__":