This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ce0ac662fe [Relax][PyTorch] Add lower bound support for range
constraints (#18447)
ce0ac662fe is described below
commit ce0ac662fe2bb26f85b43c1bcfe2deb0620aada7
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 14 19:51:49 2025 +0800
[Relax][PyTorch] Add lower bound support for range constraints (#18447)
Add lower bound support for range constraints
---
include/tvm/relax/transform.h | 19 ++++++-----
.../frontend/torch/exported_program_translator.py | 6 ++--
src/relax/transform/adjust_matmul_order.cc | 32 ++++++++++++++----
src/relax/transform/static_plan_block_memory.cc | 38 ++++++++++++++--------
.../relax/test_frontend_from_exported_program.py | 2 +-
.../test_transform_static_plan_block_memory.py | 12 +++++++
6 files changed, 77 insertions(+), 32 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index a8ccc4076b..58cf7421b5 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape();
* The pass will reuse allocated memory to its best effort, in order to
* reduce the total amount of allocated memory size.
*
- * The pass "supports" dynamic shape in the way of TIR variable upper bound
- * annotation. We can optionally annotate the attribute "tir_var_upper_bound"
- * to Relax functions. The attribute value is a dict from strings to integers,
- * denoting the name of TIR variables to the upper bound values of the TIR
vars.
- * Note: The annotated upper bound attribute only applies to TIR vars in the
+ * The pass "supports" dynamic shape in the way of TIR variable bound
+ * annotations. We can optionally annotate the attributes "tir_var_upper_bound"
+ * and "tir_var_lower_bound" to Relax functions. The attribute values are dicts
+ * from strings to integers, mapping the names of TIR variables to the bound
+ * values of the TIR vars.
+ * Note: The annotated bound attributes only apply to TIR vars in the
* function signature for clarity.
*
* For example, we can annotate a Relax function with
- * `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
- * It means the maximum value of variable that names "n" in the function
- * signature will have upper bound 1024. And we will use 1024 as its value
- * during memory planning.
+ * `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound":
{"n": 1024}})`.
+ * It means the variable named "n" in the function signature will have
+ * range [1, 1024]. And we will use these bounds during memory planning.
+ * If the lower bound is not specified, it defaults to 0.
*
* \return The pass.
*/
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5cddf24a89..431a1444d1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1181,10 +1181,12 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if range_constraints:
if func_attrs is None:
func_attrs = {}
- tir_var_upper_bound = {
+ func_attrs["tir_var_lower_bound"] = {
+ var_name: lower for var_name, (lower, _) in
range_constraints.items()
+ }
+ func_attrs["tir_var_upper_bound"] = {
var_name: upper for var_name, (_, upper) in
range_constraints.items()
}
- func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
nodes: List[fx.Node] = exported_program.graph.nodes
diff --git a/src/relax/transform/adjust_matmul_order.cc
b/src/relax/transform/adjust_matmul_order.cc
index 98fe57e11c..8892720191 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -73,19 +73,37 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
pat_permuted_matmul_on_rhs;
PrimExpr symbolic_var_constraints = Bool(true);
- if (auto upper_bounds = func->GetAttr<ffi::Map<ffi::String,
Any>>("tir_var_upper_bound")) {
+ auto upper_bounds = func->GetAttr<ffi::Map<ffi::String,
Any>>("tir_var_upper_bound");
+ auto lower_bounds = func->GetAttr<ffi::Map<ffi::String,
Any>>("tir_var_lower_bound");
+
+ if (upper_bounds || lower_bounds) {
ffi::Map<ffi::String, tir::Var> name_lookup;
for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
name_lookup.Set(tir_var->name_hint, tir_var);
symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
}
- for (const auto& [key, obj_bound] : upper_bounds.value()) {
- auto tir_var_name = Downcast<ffi::String>(key);
- if (auto opt_var = name_lookup.Get(tir_var_name)) {
- auto var = opt_var.value();
- auto expr_bound = Downcast<PrimExpr>(obj_bound);
- symbolic_var_constraints = symbolic_var_constraints && (var <
expr_bound);
+ // Add lower bound constraints
+ if (lower_bounds) {
+ for (const auto& [key, obj_bound] : lower_bounds.value()) {
+ auto tir_var_name = Downcast<ffi::String>(key);
+ if (auto opt_var = name_lookup.Get(tir_var_name)) {
+ auto var = opt_var.value();
+ auto expr_bound = Downcast<PrimExpr>(obj_bound);
+ symbolic_var_constraints = symbolic_var_constraints && (expr_bound
<= var);
+ }
+ }
+ }
+
+ // Add upper bound constraints
+ if (upper_bounds) {
+ for (const auto& [key, obj_bound] : upper_bounds.value()) {
+ auto tir_var_name = Downcast<ffi::String>(key);
+ if (auto opt_var = name_lookup.Get(tir_var_name)) {
+ auto var = opt_var.value();
+ auto expr_bound = Downcast<PrimExpr>(obj_bound);
+ symbolic_var_constraints = symbolic_var_constraints && (var <
expr_bound);
+ }
}
}
}
diff --git a/src/relax/transform/static_plan_block_memory.cc
b/src/relax/transform/static_plan_block_memory.cc
index 85076206ae..fc3c2259ff 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -365,40 +365,52 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
};
/*!
- * \brief Set the upper bound of the TIR variables that appear in
+ * \brief Set the range constraints of the TIR variables that appear in
* the input function signature in the analyzer.
* \param func The function to be analyzed.
* \param ana The analyzer which contains the TIR var upper bounds.
* \param dom_map The domain map of the TIR variables.
*/
-void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
- ffi::Map<tir::Var, arith::IntSet>* dom_map) {
- // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
+ ffi::Map<tir::Var, arith::IntSet>* dom_map) {
+ // Use the attribute-annotated TIR var bounds as the TIR var values for
// memory planning.
- // NOTE: we only apply the annotated upper bounds to the TIR variables that
+ // NOTE: we only apply the annotated bounds to the TIR variables that
// appear in the **function signature**.
ffi::Map<ffi::String, IntImm> var_upper_bound_attr_raw =
func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_upper_bound")
.value_or(ffi::Map<ffi::String, IntImm>());
+ ffi::Map<ffi::String, IntImm> var_lower_bound_attr_raw =
+ func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_lower_bound")
+ .value_or(ffi::Map<ffi::String, IntImm>());
ffi::Array<ffi::String> non_negative_var_attr_raw =
func->GetAttr<ffi::Array<ffi::String>>("tir_non_negative_var")
.value_or(ffi::Array<ffi::String>());
std::unordered_map<ffi::String, IntImm> var_upper_bound_attr;
+ std::unordered_map<ffi::String, IntImm> var_lower_bound_attr;
std::unordered_set<ffi::String> non_negative_var_attr;
// We manually check the value type to ensure the values are all positive
IntImm.
for (auto [key, value] : var_upper_bound_attr_raw) {
var_upper_bound_attr[key] = value;
}
+ for (auto [key, value] : var_lower_bound_attr_raw) {
+ var_lower_bound_attr[key] = value;
+ }
for (const ffi::String& var_name : non_negative_var_attr_raw) {
non_negative_var_attr.insert(var_name);
}
ffi::Array<tir::Var> var_in_signature =
TIRVarsInStructInfo(GetStructInfo(func));
for (const tir::Var& tir_var : var_in_signature) {
- auto it = var_upper_bound_attr.find(tir_var->name_hint);
- if (it != var_upper_bound_attr.end()) {
- tvm::Range range =
- tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
- tvm::IntImm(DataType::Int(64),
(*it).second->value + 1));
+ auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
+ auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
+
+ if (it_upper != var_upper_bound_attr.end() || it_lower !=
var_lower_bound_attr.end()) {
+ int64_t lower = (it_lower != var_lower_bound_attr.end()) ?
it_lower->second->value : 0;
+ int64_t upper = (it_upper != var_upper_bound_attr.end())
+ ? it_upper->second->value
+ : std::numeric_limits<int64_t>::max();
+ tvm::Range range = tvm::Range::FromMinExtent(
+ tvm::IntImm(DataType::Int(64), lower),
tvm::IntImm(DataType::Int(64), upper - lower + 1));
ana->Bind(tir_var, range);
dom_map->Set(tir_var, arith::IntSet::FromRange(range));
} else if (non_negative_var_attr.count(tir_var->name_hint)) {
@@ -485,8 +497,8 @@ class StorageAllocatorInit : public
StorageAllocatorBaseVisitor {
: ctx_mod_(ctx_mod), analyzer_(analyzer) {}
void VisitExpr_(const FunctionNode* func) final {
- // Set the upper bound of TIR variables in the analyzer.
- SetTIRVarUpperBound(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
+ // Set the range constraints of TIR variables in the analyzer.
+ SetTIRVarRangeConstraints(ffi::GetRef<Function>(func), analyzer_,
&dom_map_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are
external referenced.
@@ -843,7 +855,7 @@ class StorageAllocationRewriter : public ExprMutator {
plan_dynamic_output_ = static_cast<bool>(
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32),
0))->value);
if (plan_dynamic_output_) {
- SetTIRVarUpperBound(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
+ SetTIRVarRangeConstraints(ffi::GetRef<Function>(func_), &ana_,
&dom_map_);
}
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 71e400a6a8..157af43fac 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6747,7 +6747,7 @@ def test_dynamic_shape_with_range_constraints():
x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4),
dtype="float32")
) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
s0 = T.int64(is_size_var=True)
- R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+ R.func_attr({"tir_var_lower_bound": {"s0": 1},
"tir_var_upper_bound": {"s0": 64}})
with R.dataflow():
lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 83e4d264c6..06e4ea142e 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1347,6 +1347,18 @@ def test_invalid_tir_var_upper_bound():
relax.transform.StaticPlanBlockMemory()(Module)
+def test_invalid_tir_var_lower_bound():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((2, "n"), dtype="float32")):
+ R.func_attr({"tir_var_lower_bound": {"n": [4]},
"relax.force_pure": True})
+ return x
+
+ with pytest.raises((TVMError, TypeError)):
+ relax.transform.StaticPlanBlockMemory()(Module)
+
+
def test_add():
@I.ir_module
class Module: