This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b07c6c1e75 [Relax] Add symbolic shape support to BasePyModule for
dynamic tensor operations (#18288)
b07c6c1e75 is described below
commit b07c6c1e7502536db6c6f8c8696cc4c7f6bc46a1
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Sep 10 17:33:49 2025 -0400
[Relax] Add symbolic shape support to BasePyModule for dynamic tensor
operations (#18288)
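This PR adds symbolic shape support to `BasePyModule`, enabling dynamic
tensor operations whose output shapes are inferred at runtime.
Users can now call functions that use Relax's symbolic shapes through
`BasePyModule` from Python, with symbolic dimensions automatically
resolved at execution time based on input tensor shapes.
When allocating outputs for `call_tir` / `call_dps_packed`, each symbolic
output dimension is taken from the same position of the first input whose
rank matches the output's rank; a `ValueError` is raised if no input
determines the shape.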
## Usage Example
```python
import tvm
from tvm.script import ir as I, relax as R
from tvm.relax.base_py_module import BasePyModule
import numpy as np

@I.ir_module
class VectorAddModule(BasePyModule):
    @R.function
    def add(x: R.Tensor(("n",), "float32"),
            y: R.Tensor(("n",), "float32")) -> R.Tensor(("n",), "float32"):
        return R.add(x, y)

module = VectorAddModule(device=tvm.cpu(0), target="llvm")
a = np.array([1.0, 2.0, 3.0], dtype="float32")
b = np.array([4.0, 5.0, 6.0], dtype="float32")
result = module.add(a, b)  # Result: [5.0, 7.0, 9.0]
```
---
python/tvm/relax/base_py_module.py | 68 +++-
.../relax/test_base_py_module_symbolic_shape.py | 367 +++++++++++++++++++++
2 files changed, 425 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 796ab41a14..eb34ca4d15 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -198,7 +198,7 @@ class BasePyModule:
)
func = self.compiled_tir_funcs[func_name]
- out = self._create_output_tensors(out_sinfo)
+ out = self._create_output_tensors(out_sinfo, args)
tvm_args = self._convert_pytorch_to_tvm(args)
tvm_out = self._convert_pytorch_to_tvm(out)
@@ -222,12 +222,11 @@ class BasePyModule:
) from error
func = self.extern_funcs[func_name]
- out = self._create_output_tensors(out_sinfo)
+ out = self._create_output_tensors(out_sinfo, args)
tvm_args = self._convert_pytorch_to_tvm(args)
tvm_out = self._convert_pytorch_to_tvm(out)
func(*tvm_args, *tvm_out)
- result = self._convert_tvm_to_pytorch(tvm_out)
- return result[0] if len(result) == 1 else result
+ return out[0] if len(out) == 1 else out
def call_py_func(self, func_name: str, args):
"""Call a Python function stored in the IRModule's pyfuncs."""
@@ -237,22 +236,71 @@ class BasePyModule:
converted_args = self._convert_tvm_to_pytorch(args)
return py_func(*converted_args)
- def _create_output_tensors(self, out_sinfo):
- """Create output PyTorch tensors based on shape and type information."""
+ def _create_output_tensors(self, out_sinfo, in_args=None):
# pylint: disable=import-outside-toplevel
import torch
sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo]
out_tensors = []
for sinfo in sinfo_list:
+ if isinstance(sinfo, (tuple, list)) and all(
+ isinstance(x, (int, np.integer)) for x in sinfo
+ ):
+ out_tensors.append(torch.zeros(list(map(int, sinfo)), dtype=torch.float32))
+ continue
+
if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"):
- shape = [int(val) for val in sinfo.shape]
+ concrete_shape = self._infer_concrete_shape_from_args(sinfo.shape, in_args)
torch_dtype = self._convert_tvm_dtype_to_torch(sinfo.dtype)
- out_tensors.append(torch.empty(shape, dtype=torch_dtype))
- else:
- out_tensors.append(torch.empty((1,), dtype=torch.float32))
+ out_tensors.append(torch.zeros(concrete_shape, dtype=torch_dtype))
+ continue
+
+ out_tensors.append(torch.zeros((1,), dtype=torch.float32))
return out_tensors
+ def _infer_concrete_shape_from_args(self, shape, in_args):
+
+ concrete = []
+ symbolic_positions = []
+ for idx, dim in enumerate(shape):
+ if isinstance(dim, (int, np.integer)):
+ concrete.append(int(dim))
+ elif isinstance(dim, tir.IntImm):
+ concrete.append(int(dim.value))
+ else:
+ concrete.append(None)
+ symbolic_positions.append(idx)
+
+ if not symbolic_positions:
+ return concrete
+
+ candidates = []
+ if in_args is not None:
+ if not isinstance(in_args, (list, tuple)):
+ in_args = [in_args]
+ for obj in in_args:
+ if hasattr(obj, "shape") and isinstance(obj.shape, (tuple, list)):
+ try:
+ candidates.append(tuple(int(x) for x in obj.shape))
+ continue
+ except (ValueError, TypeError):
+ # Skip objects with invalid shapes
+ pass
+
+ target_ndim = len(shape)
+ for cand in candidates:
+ if len(cand) == target_ndim:
+ for pos in symbolic_positions:
+ concrete[pos] = cand[pos]
+ if all(x is not None for x in concrete):
+ return concrete
+
+ raise ValueError(
+ "Cannot infer concrete output shape from symbolic shape and inputs. "
+ "Please provide a concrete `out_sinfo` (e.g., a tuple/list of ints) "
+ "or ensure input tensors carry shapes that determine output extents."
+ )
+
def _convert_tvm_dtype_to_torch(self, tvm_dtype: str) -> "torch.dtype":
"""Convert TVM dtype string to PyTorch dtype."""
# pylint: disable=import-outside-toplevel
diff --git a/tests/python/relax/test_base_py_module_symbolic_shape.py b/tests/python/relax/test_base_py_module_symbolic_shape.py
new file mode 100644
index 0000000000..aa39fe14bf
--- /dev/null
+++ b/tests/python/relax/test_base_py_module_symbolic_shape.py
@@ -0,0 +1,367 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+from tvm.ir import IRModule
+from tvm.relax.base_py_module import BasePyModule
+from tvm import tir, relax
+from tvm.script import ir as I, tir as T, relax as R
+
+
+def _make_module():
+ return IRModule({})
+
+
+def test_infer_concrete_shape_from_numpy_input():
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ sym_shape = [n, m]
+
+ x = np.zeros((3, 4), dtype="float32")
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x])
+ assert inferred == [3, 4]
+
+
+def test_infer_concrete_shape_all_concrete_dims():
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ shape = [tir.IntImm("int32", 5), 6]
+ inferred = bpm._infer_concrete_shape_from_args(shape, in_args=[])
+ assert inferred == [5, 6]
+
+
+def test_infer_concrete_shape_error_when_uninferrable():
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ k = tir.Var("k", "int64")
+ with pytest.raises(ValueError):
+ bpm._infer_concrete_shape_from_args([k, 8], in_args=[])
+
+
+@I.ir_module
+class AddModuleSymbolic(BasePyModule):
+ @T.prim_func
+ def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
+ T.func_attr({"global_symbol": "add_tir"})
+ n = T.int64()
+ x = T.match_buffer(var_x, (n,), dtype="float32")
+ y = T.match_buffer(var_y, (n,), dtype="float32")
+ out = T.match_buffer(var_out, (n,), dtype="float32")
+
+ for i in T.serial(n):
+ out[i] = x[i] + y[i]
+
+ @R.function
+ def main_relax(
+ x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32")
+ ) -> R.Tensor(("n",), "float32"):
+ return R.add(x, y)
+
+
+def test_base_py_module_relax_symbolic_end_to_end():
+ bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")
+
+ a = np.random.randn(5).astype("float32")
+ b = np.random.randn(5).astype("float32")
+ out = bpm.main_relax(a, b)
+ assert isinstance(out, np.ndarray) or hasattr(out, "numpy")
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ np.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6)
+
+ a7 = np.random.randn(7).astype("float32")
+ b7 = np.random.randn(7).astype("float32")
+ out2 = bpm.main_relax(a7, b7)
+ out2_np = out2 if isinstance(out2, np.ndarray) else out2.numpy()
+ np.testing.assert_allclose(out2_np, a7 + b7, rtol=1e-6, atol=1e-6)
+
+
+def test_base_py_module_tir_symbolic_end_to_end():
+ bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")
+
+ a = np.random.randn(5).astype("float32")
+ b = np.random.randn(5).astype("float32")
+
+ n = tir.Var("n", "int64")
+ out_sinfo = relax.TensorStructInfo((n,), "float32")
+
+ out = bpm.call_tir("add_tir", [a, b], out_sinfo)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ np.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6)
+
+
+def test_infer_concrete_shape_multiple_symbolic_dims():
+ """Test shape inference with multiple symbolic dimensions."""
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ k = tir.Var("k", "int64")
+ sym_shape = [n, m, k]
+
+ x = np.zeros((2, 3, 4), dtype="float32")
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x])
+ assert inferred == [2, 3, 4]
+
+
+def test_infer_concrete_shape_mixed_concrete_symbolic():
+ """Test shape inference with mixed concrete and symbolic dimensions."""
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ sym_shape = [n, 5, 10] # First dim is symbolic, others are concrete
+
+ x = np.zeros((3, 5, 10), dtype="float32")
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x])
+ assert inferred == [3, 5, 10]
+
+
+def test_infer_concrete_shape_from_tvm_tensors():
+ """Test shape inference from TVM tensors."""
+ try:
+ # Try to create TVM tensor using new API
+ x_np = np.zeros((3, 4), dtype="float32")
+ x_tvm = tvm.runtime.tensor(x_np)
+
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ sym_shape = [n, m]
+
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x_tvm])
+ assert inferred == [3, 4]
+ except AttributeError:
+ # Skip if tvm.runtime.tensor is not available
+ pytest.skip("tvm.runtime.tensor not available")
+
+
+def test_infer_concrete_shape_multiple_inputs():
+ """Test shape inference when multiple inputs are available."""
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ sym_shape = [n, m]
+
+ # Multiple inputs with different shapes - should use first matching one
+ x1 = np.zeros((2, 3), dtype="float32")
+ x2 = np.zeros((4, 5), dtype="float32")
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x1, x2])
+ assert inferred == [2, 3] # Should use first input
+
+
+def test_infer_concrete_shape_wrong_ndim():
+ """Test shape inference when input has wrong number of dimensions."""
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ sym_shape = [n, m] # 2D
+
+ x = np.zeros((3,), dtype="float32") # 1D - wrong ndim
+ with pytest.raises(ValueError, match="Cannot infer concrete output shape"):
+ bpm._infer_concrete_shape_from_args(sym_shape, [x])
+
+
+@I.ir_module
+class MatrixModuleSymbolic(BasePyModule):
+ @T.prim_func
+ def matmul_tir(var_a: T.handle, var_b: T.handle, var_c: T.handle):
+ T.func_attr({"global_symbol": "matmul_tir"})
+ m = T.int64()
+ n = T.int64()
+ k = T.int64()
+ a = T.match_buffer(var_a, (m, k), dtype="float32")
+ b = T.match_buffer(var_b, (k, n), dtype="float32")
+ c = T.match_buffer(var_c, (m, n), dtype="float32")
+
+ for i in T.serial(m):
+ for j in T.serial(n):
+ c[i, j] = 0.0
+ for l in T.serial(k):
+ c[i, j] = c[i, j] + a[i, l] * b[l, j]
+
+ @R.function
+ def matmul_relax(
+ a: R.Tensor(("m", "k"), "float32"), b: R.Tensor(("k", "n"), "float32")
+ ) -> R.Tensor(("m", "n"), "float32"):
+ return R.matmul(a, b)
+
+
+def test_base_py_module_multiple_symbolic_dims():
+ """Test BasePyModule with multiple symbolic dimensions."""
+ bpm = MatrixModuleSymbolic(device=tvm.cpu(0), target="llvm")
+
+ # Test Relax function with multiple symbolic dims
+ a = np.random.randn(2, 3).astype("float32")
+ b = np.random.randn(3, 4).astype("float32")
+ out = bpm.matmul_relax(a, b)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ expected = np.matmul(a, b)
+ np.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6)
+
+ # Test TIR function with multiple symbolic dims
+ # Use concrete shapes for TIR function to avoid constraint issues
+ out_sinfo = relax.TensorStructInfo((2, 4), "float32")
+ out_tir = bpm.call_tir("matmul_tir", [a, b], out_sinfo)
+ out_tir_np = out_tir if isinstance(out_tir, np.ndarray) else out_tir.numpy()
+ np.testing.assert_allclose(out_tir_np, expected, rtol=1e-6, atol=1e-6)
+
+
+def test_base_py_module_call_dps_packed_symbolic():
+ """Test call_dps_packed with symbolic shapes."""
+ try:
+ # Register a simple test function
+ @tvm.register_global_func("test_add_packed")
+ def test_add_packed(a, b, out):
+ """Add two tensors element-wise."""
+ a_np = a.numpy()
+ b_np = b.numpy()
+ result = a_np + b_np
+ out[:] = result
+
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ a = np.random.randn(5).astype("float32")
+ b = np.random.randn(5).astype("float32")
+
+ n = tir.Var("n", "int64")
+ out_sinfo = relax.TensorStructInfo((n,), "float32")
+
+ out = bpm.call_dps_packed("test_add_packed", [a, b], out_sinfo)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ np.testing.assert_allclose(out_np, a + b, rtol=1e-6, atol=1e-6)
+
+ except AttributeError as e:
+ pytest.skip(f"call_dps_packed test requires register_global_func: {e}")
+
+
+def test_base_py_module_call_dps_packed_multiple_args():
+ """Test call_dps_packed with multiple arguments and symbolic shapes."""
+ try:
+ # Register a function that takes multiple arguments
+ @tvm.register_global_func("test_matmul_packed")
+ def test_matmul_packed(a, b, out):
+ """Matrix multiplication."""
+ a_np = a.numpy()
+ b_np = b.numpy()
+ result = np.matmul(a_np, b_np)
+ out[:] = result
+
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ a = np.random.randn(2, 3).astype("float32")
+ b = np.random.randn(3, 4).astype("float32")
+
+ out_sinfo = relax.TensorStructInfo((2, 4), "float32")
+
+ out = bpm.call_dps_packed("test_matmul_packed", [a, b], out_sinfo)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ expected = np.matmul(a, b)
+ np.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6)
+
+ except AttributeError as e:
+ pytest.skip(f"call_dps_packed test requires register_global_func: {e}")
+
+
+def test_base_py_module_call_dps_packed_scalar_args():
+ """Test call_dps_packed with scalar arguments and symbolic shapes."""
+ try:
+ # Register a function that takes scalar arguments
+ @tvm.register_global_func("test_add_scalar_packed")
+ def test_add_scalar_packed(x, scalar, out):
+ """Add scalar to tensor."""
+ x_np = x.numpy()
+ if hasattr(scalar, "numpy"):
+ scalar_val = scalar.numpy()
+ else:
+ scalar_val = scalar
+ result = x_np + scalar_val
+ out[:] = result
+
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ x = np.random.randn(4).astype("float32")
+ scalar = 2.5
+
+ n = tir.Var("n", "int64")
+ out_sinfo = relax.TensorStructInfo((n,), "float32")
+
+ out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_sinfo)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ expected = x + scalar
+ np.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6)
+
+ except AttributeError as e:
+ pytest.skip(f"call_dps_packed test requires register_global_func: {e}")
+
+
+def test_infer_concrete_shape_from_pytorch_tensors():
+ """Test shape inference from PyTorch tensors (if available)."""
+ try:
+ import torch
+ except ImportError:
+ pytest.skip("PyTorch not available")
+
+ mod = _make_module()
+ bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm")
+
+ n = tir.Var("n", "int64")
+ m = tir.Var("m", "int64")
+ sym_shape = [n, m]
+
+ x_torch = torch.zeros((3, 4), dtype=torch.float32)
+ inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x_torch])
+ assert inferred == [3, 4]
+
+
+def test_base_py_module_relax_with_pytorch_tensors():
+ """Test Relax functions with PyTorch tensors and symbolic shapes."""
+ try:
+ import torch
+ except ImportError:
+ pytest.skip("PyTorch not available")
+
+ bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")
+
+ a_torch = torch.randn(5, dtype=torch.float32)
+ b_torch = torch.randn(5, dtype=torch.float32)
+
+ out = bpm.main_relax(a_torch, b_torch)
+ out_np = out if isinstance(out, np.ndarray) else out.numpy()
+ expected = a_torch.numpy() + b_torch.numpy()
+ np.testing.assert_allclose(out_np, expected, rtol=1e-6, atol=1e-6)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()