This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce0ac662fe [Relax][PyTorch] Add lower bound support for range constraints (#18447)
ce0ac662fe is described below

commit ce0ac662fe2bb26f85b43c1bcfe2deb0620aada7
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 14 19:51:49 2025 +0800

    [Relax][PyTorch] Add lower bound support for range constraints (#18447)
    
    Add lower bound support for range constraints
---
 include/tvm/relax/transform.h                      | 19 ++++++-----
 .../frontend/torch/exported_program_translator.py  |  6 ++--
 src/relax/transform/adjust_matmul_order.cc         | 32 ++++++++++++++----
 src/relax/transform/static_plan_block_memory.cc    | 38 ++++++++++++++--------
 .../relax/test_frontend_from_exported_program.py   |  2 +-
 .../test_transform_static_plan_block_memory.py     | 12 +++++++
 6 files changed, 77 insertions(+), 32 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index a8ccc4076b..58cf7421b5 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape();
  * The pass will reuse allocated memory to its best effort, in order to
  * reduce the total amount of allocated memory size.
  *
- * The pass "supports" dynamic shape in the way of TIR variable upper bound
- * annotation. We can optionally annotate the attribute "tir_var_upper_bound"
- * to Relax functions. The attribute value is a dict from strings to integers,
- * denoting the name of TIR variables to the upper bound values of the TIR 
vars.
- * Note: The annotated upper bound attribute only applies to TIR vars in the
+ * The pass "supports" dynamic shape in the way of TIR variable bound
+ * annotations. We can optionally annotate the attributes "tir_var_upper_bound"
+ * and "tir_var_lower_bound" to Relax functions. The attribute values are dicts
+ * from strings to integers, denoting the name of TIR variables to the bound
+ * values of the TIR vars.
+ * Note: The annotated bound attributes only apply to TIR vars in the
  * function signature for clarity.
  *
  * For example, we can annotate a Relax function with
- *   `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
- * It means the maximum value of variable that names "n" in the function
- * signature will have upper bound 1024. And we will use 1024 as its value
- * during memory planning.
+ *   `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound": 
{"n": 1024}})`.
+ * It means the variable that names "n" in the function signature will have
+ * range [1, 1024]. And we will use these bounds during memory planning.
+ * If lower bound is not specified, it defaults to 0.
  *
  * \return The pass.
  */
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5cddf24a89..431a1444d1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1181,10 +1181,12 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         if range_constraints:
             if func_attrs is None:
                 func_attrs = {}
-            tir_var_upper_bound = {
+            func_attrs["tir_var_lower_bound"] = {
+                var_name: lower for var_name, (lower, _) in 
range_constraints.items()
+            }
+            func_attrs["tir_var_upper_bound"] = {
                 var_name: upper for var_name, (_, upper) in 
range_constraints.items()
             }
-            func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
 
         nodes: List[fx.Node] = exported_program.graph.nodes
 
diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc
index 98fe57e11c..8892720191 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -73,19 +73,37 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
              pat_permuted_matmul_on_rhs;
 
   PrimExpr symbolic_var_constraints = Bool(true);
-  if (auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, 
Any>>("tir_var_upper_bound")) {
+  auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, 
Any>>("tir_var_upper_bound");
+  auto lower_bounds = func->GetAttr<ffi::Map<ffi::String, 
Any>>("tir_var_lower_bound");
+
+  if (upper_bounds || lower_bounds) {
     ffi::Map<ffi::String, tir::Var> name_lookup;
     for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
       name_lookup.Set(tir_var->name_hint, tir_var);
       symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
     }
 
-    for (const auto& [key, obj_bound] : upper_bounds.value()) {
-      auto tir_var_name = Downcast<ffi::String>(key);
-      if (auto opt_var = name_lookup.Get(tir_var_name)) {
-        auto var = opt_var.value();
-        auto expr_bound = Downcast<PrimExpr>(obj_bound);
-        symbolic_var_constraints = symbolic_var_constraints && (var < 
expr_bound);
+    // Add lower bound constraints
+    if (lower_bounds) {
+      for (const auto& [key, obj_bound] : lower_bounds.value()) {
+        auto tir_var_name = Downcast<ffi::String>(key);
+        if (auto opt_var = name_lookup.Get(tir_var_name)) {
+          auto var = opt_var.value();
+          auto expr_bound = Downcast<PrimExpr>(obj_bound);
+          symbolic_var_constraints = symbolic_var_constraints && (expr_bound 
<= var);
+        }
+      }
+    }
+
+    // Add upper bound constraints
+    if (upper_bounds) {
+      for (const auto& [key, obj_bound] : upper_bounds.value()) {
+        auto tir_var_name = Downcast<ffi::String>(key);
+        if (auto opt_var = name_lookup.Get(tir_var_name)) {
+          auto var = opt_var.value();
+          auto expr_bound = Downcast<PrimExpr>(obj_bound);
+          symbolic_var_constraints = symbolic_var_constraints && (var < 
expr_bound);
+        }
       }
     }
   }
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 85076206ae..fc3c2259ff 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -365,40 +365,52 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
 };
 
 /*!
- * \brief Set the upper bound of the TIR variables that appear in
+ * \brief Set the range constraints of the TIR variables that appear in
  * the input function signature in the analyzer.
  * \param func The function to be analyzed.
  * \param ana The analyzer which contains the TIR var upper bounds.
  * \param dom_map The domain map of the TIR variables.
  */
-void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
-                         ffi::Map<tir::Var, arith::IntSet>* dom_map) {
-  // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
+                               ffi::Map<tir::Var, arith::IntSet>* dom_map) {
+  // Use the attribute-annotated TIR var bounds as the TIR var values for
   // memory planning.
-  // NOTE: we only apply the annotated upper bounds to the TIR variables that
+  // NOTE: we only apply the annotated bounds to the TIR variables that
   // appear in the **function signature**.
   ffi::Map<ffi::String, IntImm> var_upper_bound_attr_raw =
       func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_upper_bound")
           .value_or(ffi::Map<ffi::String, IntImm>());
+  ffi::Map<ffi::String, IntImm> var_lower_bound_attr_raw =
+      func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_lower_bound")
+          .value_or(ffi::Map<ffi::String, IntImm>());
   ffi::Array<ffi::String> non_negative_var_attr_raw =
       func->GetAttr<ffi::Array<ffi::String>>("tir_non_negative_var")
           .value_or(ffi::Array<ffi::String>());
   std::unordered_map<ffi::String, IntImm> var_upper_bound_attr;
+  std::unordered_map<ffi::String, IntImm> var_lower_bound_attr;
   std::unordered_set<ffi::String> non_negative_var_attr;
   // We manually check the value type to ensure the values are all positive 
IntImm.
   for (auto [key, value] : var_upper_bound_attr_raw) {
     var_upper_bound_attr[key] = value;
   }
+  for (auto [key, value] : var_lower_bound_attr_raw) {
+    var_lower_bound_attr[key] = value;
+  }
   for (const ffi::String& var_name : non_negative_var_attr_raw) {
     non_negative_var_attr.insert(var_name);
   }
   ffi::Array<tir::Var> var_in_signature = 
TIRVarsInStructInfo(GetStructInfo(func));
   for (const tir::Var& tir_var : var_in_signature) {
-    auto it = var_upper_bound_attr.find(tir_var->name_hint);
-    if (it != var_upper_bound_attr.end()) {
-      tvm::Range range =
-          tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
-                                    tvm::IntImm(DataType::Int(64), 
(*it).second->value + 1));
+    auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
+    auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
+
+    if (it_upper != var_upper_bound_attr.end() || it_lower != 
var_lower_bound_attr.end()) {
+      int64_t lower = (it_lower != var_lower_bound_attr.end()) ? 
it_lower->second->value : 0;
+      int64_t upper = (it_upper != var_upper_bound_attr.end())
+                          ? it_upper->second->value
+                          : std::numeric_limits<int64_t>::max();
+      tvm::Range range = tvm::Range::FromMinExtent(
+          tvm::IntImm(DataType::Int(64), lower), 
tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
@@ -485,8 +497,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
       : ctx_mod_(ctx_mod), analyzer_(analyzer) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Set the upper bound of TIR variables in the analyzer.
-    SetTIRVarUpperBound(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
+    // Set the range constraints of TIR variables in the analyzer.
+    SetTIRVarRangeConstraints(ffi::GetRef<Function>(func), analyzer_, 
&dom_map_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are 
external referenced.
@@ -843,7 +855,7 @@ class StorageAllocationRewriter : public ExprMutator {
       plan_dynamic_output_ = static_cast<bool>(
           
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 
0))->value);
       if (plan_dynamic_output_) {
-        SetTIRVarUpperBound(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
+        SetTIRVarRangeConstraints(ffi::GetRef<Function>(func_), &ana_, 
&dom_map_);
       }
       token2storage_var_.clear();
       Function func = Downcast<Function>(this->VisitExpr_(func_));
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 71e400a6a8..157af43fac 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6747,7 +6747,7 @@ def test_dynamic_shape_with_range_constraints():
             x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), 
dtype="float32")
         ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
             s0 = T.int64(is_size_var=True)
-            R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+            R.func_attr({"tir_var_lower_bound": {"s0": 1}, 
"tir_var_upper_bound": {"s0": 64}})
             with R.dataflow():
                 lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
                 gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 83e4d264c6..06e4ea142e 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1347,6 +1347,18 @@ def test_invalid_tir_var_upper_bound():
         relax.transform.StaticPlanBlockMemory()(Module)
 
 
+def test_invalid_tir_var_lower_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"n": [4]}, 
"relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
 def test_add():
     @I.ir_module
     class Module:

Reply via email to