This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f5ab3f05eb [TIR] [Analysis] Calculate allocated memory at module level
(#14711)
f5ab3f05eb is described below
commit f5ab3f05eb3190b836f41bbeb975258232010def
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Tue Apr 25 09:18:42 2023 +0530
[TIR] [Analysis] Calculate allocated memory at module level (#14711)
* [TIR] [Analysis] Calculate allocated memory at module level
This patch modifies the existing analysis pass
`tir.calculate_allocated_bytes` to accept an IRModule as an argument and
return allocated bytes for all prim_funcs in the IRModule.
* Fix docstring and modify python API to be consistent with c++
---
include/tvm/tir/analysis.h | 12 +++-
python/tvm/tir/analysis/analysis.py | 21 +++++--
src/tir/analysis/calculate_allocated_memory.cc | 36 ++++++++---
...test_tir_analysis_calculate_allocated_memory.py | 69 +++++++++++++++++-----
4 files changed, 109 insertions(+), 29 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 4ed164e5ad..3b5959e781 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -266,8 +266,18 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc&
func,
/*!
* \brief Calculate the allocated memory per scope in bytes needed inside the
TIR PrimFunc
 * \param func The TIR PrimFunc for which the allocated memory size is to be
calculated
+ * \return Allocated memory size per scope in bytes inside the PrimFunc
returned as a Map with
+ * key "main" and a Map of allocated sizes as values.
*/
-TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc&
func);
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>>
CalculateAllocatedBytes(const PrimFunc& func);
+
+/*!
+ * \brief Calculate the allocated memory per scope in bytes for each function
inside the module
+ * \param mod The IRModule for which the allocated memory size has to be
calculated
+ * \return Allocated memory size per scope in bytes for each function in the
IRModule returned as a
+ Map with function names as keys and a Map of allocated sizes as
values.
+ */
+TVM_DLL tvm::Map<String, tvm::Map<String, Integer>>
CalculateAllocatedBytes(const IRModule& mod);
/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including
both high-level
diff --git a/python/tvm/tir/analysis/analysis.py
b/python/tvm/tir/analysis/analysis.py
index 5feb630e48..387ea04980 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -201,20 +201,29 @@ def calculate_constant_bytes(func: PrimFunc,
constant_byte_alignment: int) -> in
return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) #
type: ignore
-def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
+def calculate_allocated_bytes(
+ func_or_mod: Union[PrimFunc, IRModule]
+) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
"""Calculate allocated memory per memory scope required by TIR PrimFuncs.
Parameters
----------
- func: tvm.tir.PrimFunc
- The function to be detected.
+ func_or_mod: Union[PrimFunc, IRModule]
+ The function or module to be detected. If a module is passed, allocated
+ memory is calculated for all PrimFuncs inside the module
Returns
-------
- result : Dict[String, int]
- Allocated memory size per scope in bytes.
+ result : Union[Dict[str, int], Dict[str, Dict[str, int]]]
+ Allocated memory size per scope in bytes for each function in the
IRModule returned as a
+ dict with function names as keys and a dict of allocated sizes as
values. If a single
+ PrimFunc is passed, the function name is returned as "main"
"""
- return _ffi_api.calculate_allocated_bytes(func) # type: ignore
+ if not isinstance(func_or_mod, (PrimFunc, IRModule)):
+ raise TypeError(
+ f"Expected argument to be PrimFunc or IRModule, but received
{type(func_or_mod)}"
+ )
+ return _ffi_api.calculate_allocated_bytes(func_or_mod) # type: ignore
def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
diff --git a/src/tir/analysis/calculate_allocated_memory.cc
b/src/tir/analysis/calculate_allocated_memory.cc
index ffdfc1f801..8680f57e4c 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -79,16 +79,38 @@ void AllocationCalculator<T>::VisitStmt_(const T* op) {
_current_size[storage_scope] -= size;
}
-tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
- return AllocationCalculator<AllocateNode>()(func);
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const
PrimFunc& func) {
+ tvm::Map<String, tvm::Map<String, Integer> > results;
+ results.Set("main", AllocationCalculator<AllocateNode>()(func));
+ return results;
}
-TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc
func) {
- return CalculateAllocatedBytes(func);
-});
+tvm::Map<String, tvm::Map<String, Integer> > CalculateAllocatedBytes(const
IRModule& mod) {
+ tvm::Map<String, tvm::Map<String, Integer> > results;
+ for (const auto& kv : mod->functions) {
+ if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
+ String func_name = kv.first->name_hint;
+ results.Set(func_name,
AllocationCalculator<AllocateNode>()(prim_func.value()));
+ }
+ }
+ return results;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
+ .set_body_typed([](ObjectRef obj) -> tvm::Map<String, tvm::Map<String,
Integer> > {
+ if (auto func = obj.as<PrimFunc>()) {
+ return CalculateAllocatedBytes(func.value());
+ } else if (auto mod = obj.as<IRModule>()) {
+ return CalculateAllocatedBytes(mod.value());
+ } else {
+ LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or
IRModule, but gets: "
+ << obj->GetTypeKey();
+ throw;
+ }
+ });
bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
- auto sizes = CalculateAllocatedBytes(func);
+ auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
return false;
@@ -121,7 +143,7 @@ Pass VerifyVTCMLimit(Optional<Target> default_target) {
}
if (limit.has_value() && limit.value() > 0) {
- auto sizes = CalculateAllocatedBytes(func);
+ auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (vtcm_allocated.IntValue() > limit.value()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation
limit has been exceeded "
diff --git
a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
index 2311bfbbef..cb3a663c03 100644
--- a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -14,32 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import pytest
import tvm
from tvm import tir
from tvm.script import tir as T
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
[email protected]_func
-def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
- for i in T.serial(128):
- with T.block("C"):
- c[i] = a[i] * T.int8(2)
[email protected]_module
+class Module:
+ @T.prim_func
+ def scale_by_two(a: T.Buffer((128,), "int8"), c: T.Buffer((128,), "int8")):
+ for i in T.serial(128):
+ with T.block("C"):
+ c[i] = a[i] * T.int8(2)
[email protected]_func
-def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,),
"int8")):
- B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
- for i in T.serial(128):
- with T.block("B"):
- B[i] = a[i] * T.int8(2)
- for i in T.serial(128):
- with T.block("C"):
- c[i] = B[i] * T.int8(3)
+ @T.prim_func
+ def scale_by_two_three(a: T.Buffer((128,), "int8"), c: T.Buffer((128,),
"int8")):
+ B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
+ for i in T.serial(128):
+ with T.block("B"):
+ B[i] = a[i] * T.int8(2)
+ for i in T.serial(128):
+ with T.block("C"):
+ c[i] = B[i] * T.int8(3)
+
+# pylint:
enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
[email protected]("primFunc,size", [(scale_by_two, 128),
(scale_by_two_three, 256)])
[email protected](
+ "primFunc,size", [(Module["scale_by_two"], 128),
(Module["scale_by_two_three"], 256)]
+)
def test_scale_by(primFunc, size):
"""Test calculate allocated bytes per scope"""
mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
@@ -53,6 +63,8 @@ def test_scale_by(primFunc, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+ assert "main" in sizes, 'Calls with PrimFunc is expected to return with
function key as "main"'
+ sizes = sizes["main"]
assert sizes.get("global.vtcm", 0) == size
@@ -94,8 +106,35 @@ def test_matmul_mix_scope(scope, size):
mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+ assert "main" in sizes, 'Calls with PrimFunc is expected to return with
function key as "main"'
+ sizes = sizes["main"]
assert sizes.get(scope, 0) == size
+def test_full_mod_calculator():
+ def apply_schedule(sch, func_name):
+ sch.work_on(func_name)
+ block_c = sch.get_block("C")
+ sch.cache_read(block_c, 0, storage_scope="global.vtcm")
+
+ sch = tvm.tir.Schedule(Module, debug_mask="all")
+ apply_schedule(sch, "scale_by_two")
+ apply_schedule(sch, "scale_by_two_three")
+ mod = tvm.tir.transform.ConvertBlocksToOpaque()(sch.mod)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+ sizes = tvm.tir.analysis.calculate_allocated_bytes(mod)
+ assert "scale_by_two" in sizes, "Values for scale_by_two not found"
+ scale_by_two_sizes = sizes["scale_by_two"]
+ assert (
+ "global.vtcm" in scale_by_two_sizes
+ ), "Expected global.vtcm allocation to be calculated scale_by_two"
+ assert scale_by_two_sizes["global.vtcm"] == 128, "Expected the calculated
size to be 128"
+ scale_by_two_three_sizes = sizes["scale_by_two_three"]
+ assert (
+ "global.vtcm" in scale_by_two_three_sizes
+ ), "Expected global.vtcm allocation to be calculated scale_by_two_three"
+ assert scale_by_two_three_sizes["global.vtcm"] == 256, "Expected the
calculated size to be 256"
+
+
if __name__ == "__main__":
tvm.testing.main()