This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab02979a86 [AOT] Correctly calculate workspace for vector types (#17077)
ab02979a86 is described below
commit ab02979a86a44e0a4093760611c7f0ec6c6a86f7
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Jun 11 15:06:56 2024 +0100
[AOT] Correctly calculate workspace for vector types (#17077)
When calculating the size of the workspace for a given prim func, the
lanes of the data type were not being considered, meaning sizes
calculated for vector dtypes such as "float32x4" were smaller than they
should have been. This commit takes the number of lanes into account in
the calculation. For scalable vector types, whose size cannot be
calculated statically, no size is computed.
---
src/tir/usmp/utils.cc | 6 +++++-
.../test_tir_analysis_calculate_workspace.py | 20 ++++++++++++++++++--
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 88a6496859..d640e9fa07 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -181,7 +181,11 @@ Map<String, PoolAllocation> GetIOPoolAllocations(
}
static Integer CalculateExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents) {
- size_t element_size_bytes = dtype.bytes();
+ if (dtype.is_scalable_vector()) {
+ // We cannot statically calculate workspace for scalable types
+ return Integer();
+ }
+ size_t element_size_bytes = dtype.bytes() * dtype.lanes();
size_t num_elements = 1;
for (const auto& ext : extents) {
if (ext->IsInstance<IntImmNode>()) {
diff --git a/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py b/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py
index 12c892a04b..29bfc58458 100644
--- a/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py
+++ b/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py
@@ -91,6 +91,18 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl
# fmt: on
+@T.prim_func
+def prim_func_decl_vector_type(a: T.handle, b: T.handle):
+ T.func_attr({"tir.noalias": True})
+ A = T.match_buffer(a, (4,), "float32x4")
+ B = T.match_buffer(b, (4,), "float32x4")
+ C = T.decl_buffer((4,), "float32x4")
+ for i in range(3):
+ with T.block("block"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + C[vi]
+
+
@pytest.mark.parametrize("alignment,size,consts", [(1, 663552, 0), (10, 663560, 0)])
def test_global_allocates(alignment, size, consts):
primfunc = primfunc_global_allocates
@@ -105,6 +117,10 @@ def test_local_allocates(alignment, size, consts):
assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, alignment) == size
+def test_vector_type():
+ primfunc = prim_func_decl_vector_type
+ assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, 1) == 64
+
+
if __name__ == "__main__":
- test_global_allocates()
- test_local_allocates()
+ tvm.testing.main()