This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c866abc43c [Relax] Fix wrong memory planning when only lower bound was provided (#18663)
c866abc43c is described below

commit c866abc43c5525b78850d6fd24fba0d250c2949c
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Jan 17 01:01:27 2026 +0900

    [Relax] Fix wrong memory planning when only lower bound was provided (#18663)
    
    This PR fixes an issue in StaticPlanBlockMemory where dynamic shapes
    were incorrectly planned as static memory when only a lower bound was
    provided for TIR variables. Previously the pass bound such a variable
    to a range anyway, defaulting the missing upper bound to INT64_MAX, so
    allocations whose size depends on the variable could be planned as
    fixed-size storage that is too small at runtime, as the error below
    shows.
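
    Concretely, the pass now binds a TIR variable to a finite range only
    when an upper bound is given; a lower bound alone (or the non-negative
    attribute) merely marks the variable as non-negative. A hedged Python
    paraphrase of the new rule in SetTIRVarRangeConstraints (the real
    logic is the C++ change below; the helper name here is illustrative
    only):

    ```python
    def plan_tir_var(upper=None, lower=None, non_negative=False):
        """Illustrative paraphrase of the new binding rule (not a TVM API)."""
        if upper is not None:
            # Bind the variable to [lower, upper]; lower defaults to 0.
            return ("range", lower if lower is not None else 0, upper)
        if (lower is not None and lower >= 0) or non_negative:
            # Only mark the variable non-negative; its extent stays symbolic.
            return ("non_negative",)
        return ("unconstrained",)
    ```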
    
    Repro:
    <details>
    <summary>repro_dynamic_memory_plan.py</summary>
    
    ```python
    import tvm
    from tvm import relax, testing
    from tvm.relax.frontend.torch import from_exported_program
    from torch.export import Dim, export
    import torch
    
    
    class SimpleConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
    
        def forward(self, x):
            return self.conv(x)
    
    
    def main():
        model = SimpleConv().eval()
    
        example = torch.randn(2, 3, 32, 32)
        batch = Dim("batch")  # No max= specified, so upper bound is unknown
        exported = export(model, (example,), dynamic_shapes={"x": {0: batch}})
    
        mod = from_exported_program(exported)
        mod = relax.transform.DecomposeOpsForInference()(mod)
    
        target = tvm.target.Target("llvm")
        exe = tvm.compile(mod, target=target)
    
        vm = relax.VirtualMachine(exe, tvm.cpu())
        inp = tvm.runtime.from_dlpack(example)
        out = vm["main"](inp)
    
        expected = model(example).detach().numpy()
        actual = out[0].numpy()
        testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
    
    
    if __name__ == "__main__":
        main()
    ```
    
    </details>
    
    Running the repro fails with the following error:
    
    <details>
    <summary>output</summary>
    
    ```
    $ uv run python repro_dynamic_memory_plan.py
    
    /home/ubuntu/data/project/tvm-example/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:182: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:119.)
      return torch._C._cuda_getDeviceCount() > 0
    Traceback (most recent call last):
      File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 40, in <module>
        main()
      File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 32, in main
        out = vm["main"](inp)
              ^^^^^^^^^^^^^^^
      File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 549, in tvm::runtime::vm::VirtualMachineImpl::InvokeClosurePacked(tvm::ffi::ObjectRef const&, tvm::ffi::PackedArgs, tvm::ffi::Any*)
        clo->impl.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv);

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 622, in operator()
        *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs);

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 693, in tvm::runtime::vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::ffi::Any, std::allocator<tvm::ffi::Any> > const&)
        RunLoop();

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 816, in tvm::runtime::vm::VirtualMachineImpl::RunLoop()
        this->RunInstrCall(curr_frame, instr);

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 767, in tvm::runtime::vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::vm::VMFrame*, tvm::runtime::vm::Instruction)
        this->InvokeClosurePacked(func_pool_[instr.func_idx].cast<ObjectRef>(), args, &ret);

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/builtin.cc", line 405, in operator()
        *rv = sobj->AllocTensor(offset, shape, dtype);

      File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/memory/memory_manager.cc", line 98, in tvm::runtime::memory::StorageObj::AllocTensor(long, tvm::ffi::Shape, DLDataType)
        ICHECK(offset + needed_size <= this->buffer.size)

      File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 321, in tvm::runtime::detail::LogFatal::~LogFatal()
        GetEntry().Finalize();

      File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 337, in tvm::runtime::detail::LogFatal::Entry::Finalize()
        InternalError error(file_, lineno_, stream_.str());

    tvm.error.InternalError: Check failed: (offset + needed_size <= this->buffer.size) is false: storage allocation failure, attempted to allocate 524288 at offset 0 in region that is 262144bytes
    ```
    
    </details>
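
    Note that the repro's comment already points at the trigger: no max=
    was given for the dynamic dimension. One hedged way to keep fully
    static planning is to bound the dimension at export time, e.g.
    replacing the two export lines of the repro above (whether the bound
    actually reaches the planner as tir_var_upper_bound depends on the
    frontend propagating it):

    ```python
    # Bounded dynamic batch dimension instead of an open-ended one.
    batch = Dim("batch", min=1, max=32)
    exported = export(model, (example,), dynamic_shapes={"x": {0: batch}})
    ```
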
---
 src/relax/transform/static_plan_block_memory.cc    |  11 +-
 .../test_transform_static_plan_block_memory.py     | 263 +++++++++++++++++++--
 2 files changed, 246 insertions(+), 28 deletions(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index cebbaa4ce5..2a41687983 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -437,15 +437,18 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
     auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
     auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
 
-    if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
+    // Only bind the variable to a range if an upper bound is explicitly provided.
+    // Without an upper bound, memory planning cannot determine the required storage size,
+    // so we skip binding and let the variable remain unbounded.
+    if (it_upper != var_upper_bound_attr.end()) {
       int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
-      int64_t upper = (it_upper != var_upper_bound_attr.end())
-                          ? it_upper->second->value
-                          : std::numeric_limits<int64_t>::max();
+      int64_t upper = it_upper->second->value;
       tvm::Range range = tvm::Range::FromMinExtent(
           tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
+    } else if (it_lower != var_lower_bound_attr.end() && it_lower->second->value >= 0) {
+      ana->MarkGlobalNonNegValue(tir_var);
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
       ana->MarkGlobalNonNegValue(tir_var);
     }
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 06e4ea142e..87c6f12f53 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1018,6 +1018,245 @@ def test_tir_var_upper_bound():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_lower_bound_only():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n + 2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"), R.prim_value(0))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_upper_and_lower_bounds():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_invalid_tir_var_upper_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
+def test_invalid_tir_var_lower_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
 def test_tir_var_decreasing_monotone():
     # fmt: off
     @I.ir_module
@@ -1335,30 +1574,6 @@ def test_function_independence():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_invalid_tir_var_upper_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
-def test_invalid_tir_var_lower_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
 def test_add():
     @I.ir_module
     class Module:
