This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c866abc43c [Relax] Fix wrong memory planning when only lower bound was provided (#18663)
c866abc43c is described below
commit c866abc43c5525b78850d6fd24fba0d250c2949c
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Jan 17 01:01:27 2026 +0900
[Relax] Fix wrong memory planning when only lower bound was provided (#18663)
This PR fixes an issue in StaticPlanBlockMemory where dynamic shapes
were incorrectly planned as static memory when only a lower bound was
provided for TIR variables.
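For context, `StaticPlanBlockMemory` reads the per-function attributes `tir_var_upper_bound` and `tir_var_lower_bound`. With this change, a static storage size is only folded in when an upper bound is actually available; a variable that only has a lower bound keeps a symbolic allocation. As a minimal sketch (not part of this patch; the module `mod`, the dimension name `"batch"`, and the bound value are illustrative), an explicit upper bound can be attached to a function like this:
```python
from tvm import relax

# Sketch, assuming `mod` is a Relax IRModule whose "main" function has a
# dynamic TIR dimension named "batch"; the bound value 8 is made up.
mod["main"] = mod["main"].with_attr("tir_var_upper_bound", {"batch": 8})
# With an upper bound attached, the planner can size a static region;
# with only a lower bound, allocations now remain symbolic.
mod = relax.transform.StaticPlanBlockMemory()(mod)
```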
Repro:
<details>
<summary>repro_dynamic_memory_plan.py</summary>
```python
import tvm
from tvm import relax, testing
from tvm.relax.frontend.torch import from_exported_program
from torch.export import Dim, export
import torch
class SimpleConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def main():
    model = SimpleConv().eval()
    example = torch.randn(2, 3, 32, 32)
    batch = Dim("batch")  # No max= specified, so upper bound is unknown
    exported = export(model, (example,), dynamic_shapes={"x": {0: batch}})
    mod = from_exported_program(exported)
    mod = relax.transform.DecomposeOpsForInference()(mod)
    target = tvm.target.Target("llvm")
    exe = tvm.compile(mod, target=target)
    vm = relax.VirtualMachine(exe, tvm.cpu())
    inp = tvm.runtime.from_dlpack(example)
    out = vm["main"](inp)
    expected = model(example).detach().numpy()
    actual = out[0].numpy()
    testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    main()
```
</details>
This will fail with the following error.
<details>
<summary>output</summary>
```
$ uv run python repro_dynamic_memory_plan.py
/home/ubuntu/data/project/tvm-example/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:182: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:119.)
  return torch._C._cuda_getDeviceCount() > 0
Traceback (most recent call last):
  File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 40, in <module>
    main()
  File "/home/ubuntu/data/project/tvm-example/frontend/repro_dynamic_memory_plan.py", line 32, in main
    out = vm["main"](inp)
          ^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 549, in tvm::runtime::vm::VirtualMachineImpl::InvokeClosurePacked(tvm::ffi::ObjectRef const&, tvm::ffi::PackedArgs, tvm::ffi::Any*)
    clo->impl.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv);
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 622, in operator()
    *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs);
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 693, in tvm::runtime::vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::ffi::Any, std::allocator<tvm::ffi::Any> > const&)
    RunLoop();
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 816, in tvm::runtime::vm::VirtualMachineImpl::RunLoop()
    this->RunInstrCall(curr_frame, instr);
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/vm.cc", line 767, in tvm::runtime::vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::vm::VMFrame*, tvm::runtime::vm::Instruction)
    this->InvokeClosurePacked(func_pool_[instr.func_idx].cast<ObjectRef>(), args, &ret);
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/vm/builtin.cc", line 405, in operator()
    *rv = sobj->AllocTensor(offset, shape, dtype);
  File "/home/ubuntu/data/project/tvm-example/tvm/src/runtime/memory/memory_manager.cc", line 98, in tvm::runtime::memory::StorageObj::AllocTensor(long, tvm::ffi::Shape, DLDataType)
    ICHECK(offset + needed_size <= this->buffer.size)
  File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 321, in tvm::runtime::detail::LogFatal::~LogFatal()
    GetEntry().Finalize();
  File "/home/ubuntu/data/project/tvm-example/tvm/include/tvm/runtime/logging.h", line 337, in tvm::runtime::detail::LogFatal::Entry::Finalize()
    InternalError error(file_, lineno_, stream_.str());
tvm.error.InternalError: Check failed: (offset + needed_size <= this->buffer.size) is false: storage allocation failure, attempted to allocate 524288 at offset 0 in region that is 262144bytes
```
</details>
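Independent of the compiler-side fix, the comment in the repro ("No max= specified, so upper bound is unknown") points at an export-time way to supply the missing bound. A hedged variant of the repro's export call (the min/max values are illustrative):
```python
from torch.export import Dim, export

# Variant of the export call from the repro above; the min/max values are
# illustrative. Per the repro's comment, an explicit max is what makes the
# upper bound of the batch dimension known downstream.
batch = Dim("batch", min=1, max=8)
exported = export(model, (example,), dynamic_shapes={"x": {0: batch}})
```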
---
src/relax/transform/static_plan_block_memory.cc | 11 +-
.../test_transform_static_plan_block_memory.py | 263 +++++++++++++++++++--
2 files changed, 246 insertions(+), 28 deletions(-)
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index cebbaa4ce5..2a41687983 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -437,15 +437,18 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
     auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
     auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
-    if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
+    // Only bind the variable to a range if an upper bound is explicitly provided.
+    // Without an upper bound, memory planning cannot determine the required storage size,
+    // so we skip binding and let the variable remain unbounded.
+    if (it_upper != var_upper_bound_attr.end()) {
       int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
-      int64_t upper = (it_upper != var_upper_bound_attr.end())
-                          ? it_upper->second->value
-                          : std::numeric_limits<int64_t>::max();
+      int64_t upper = it_upper->second->value;
       tvm::Range range = tvm::Range::FromMinExtent(
           tvm::IntImm(DataType::Int(64), lower),
           tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
+    } else if (it_lower != var_lower_bound_attr.end() && it_lower->second->value >= 0) {
+      ana->MarkGlobalNonNegValue(tir_var);
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
       ana->MarkGlobalNonNegValue(tir_var);
     }
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 06e4ea142e..87c6f12f53 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1018,6 +1018,245 @@ def test_tir_var_upper_bound():
     tvm.ir.assert_structural_equal(mod, Expected)
+def test_lower_bound_only():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n + 2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"), R.prim_value(0))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_upper_and_lower_bounds():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_invalid_tir_var_upper_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
+def test_invalid_tir_var_lower_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
 def test_tir_var_decreasing_monotone():
     # fmt: off
     @I.ir_module
@@ -1335,30 +1574,6 @@ def test_function_independence():
     tvm.ir.assert_structural_equal(mod, Expected)
-def test_invalid_tir_var_upper_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
-def test_invalid_tir_var_lower_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
 def test_add():
     @I.ir_module
     class Module: