This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b48fcaba22 [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit (#14567)
b48fcaba22 is described below
commit b48fcaba227c6d455c30bec2216183fed9853677
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Apr 13 14:16:30 2023 -0500
[TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit (#14567)
* [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit
In VerifyVTCMLimit, read the VTCM capacity directly from the function's
`target` attribute if the function has already been annotated with a target.
* Retain passing the target to VerifyVTCMLimit, as a fallback for functions
that are not annotated with a target.
---
include/tvm/tir/analysis.h | 6 ++--
src/auto_scheduler/feature.cc | 4 +--
src/driver/driver_api.cc | 11 +-------
src/tir/analysis/calculate_allocated_memory.cc | 39 +++++++++++++++++++-------
4 files changed, 35 insertions(+), 25 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 5bac25faa5..4ed164e5ad 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -26,6 +26,7 @@
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
+#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
@@ -348,12 +349,13 @@ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
/*!
* \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit
*
- * \param limit The limit to check.
+ * \param target The target whose VTCM limit should be used for any
+ * functions not already annotated with `tvm::attr::kTarget`.
*
* \returns The pass.
* \sa tvm::tir::CalculateAllocatedBytes
*/
-TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
+TVM_DLL Pass VerifyVTCMLimit(Optional<Target> target = NullOpt);
/*!
* \brief Statically check TIR code for out of bounds array access.
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 884215c24a..65cc13eb61 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1408,9 +1408,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
}
if (IsHexagonTask(task)) {
Target target = task->target;
- const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
- const auto& optimize =
- tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
+ const auto& optimize = tir::transform::Sequential({tir::transform::VerifyVTCMLimit(target)});
optimize(mod);
}
const auto& optimize =
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 1962b9ab3b..486b40c994 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -544,22 +544,13 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}
-int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
- if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
- if (target.defined() && target->kind->name == "hexagon") {
- auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
- if (value > 0) return value;
- }
- return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
-}
-
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::PassContext pass_ctx = transform::PassContext::Current();
Array<Pass> mixed_pass_list;
// VerifyVTCMLimit must occur before LowerVtcmAlloc
- mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
+ mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index 95fd7f134e..ffdfc1f801 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -96,20 +96,39 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
return true;
}
+int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
+ if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
+ if (target.defined() && target->kind->name == "hexagon") {
+ auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
+ if (value > 0) return value;
+ }
+ return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
+}
+
namespace transform {
-Pass VerifyVTCMLimit(const Integer& limit) {
+Pass VerifyVTCMLimit(Optional<Target> default_target) {
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
- if (auto func = kv.second.as<PrimFunc>()) {
- auto sizes = CalculateAllocatedBytes(func.value());
- const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
- if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
- LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
- "exceeded(allocated: "
- << vtcm_allocated << ", limit: " << limit << ").\n"
- << "In function\n"
- << func;
+ if (auto opt = kv.second.as<PrimFunc>()) {
+ auto func = opt.value();
+
+ std::optional<int64_t> limit = std::nullopt;
+ if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
+ limit = GetVTCMCapacity(func_target.value(), ctx);
+ } else if (default_target) {
+ limit = GetVTCMCapacity(default_target.value(), ctx);
+ }
+
+ if (limit.has_value() && limit.value() > 0) {
+ auto sizes = CalculateAllocatedBytes(func);
+ const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
+ if (vtcm_allocated.IntValue() > limit.value()) {
+ LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
+ << "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n"
+ << "In function\n"
+ << func;
+ }
}
}
}