This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4041f890ce [Relax] Introduce R.call_py_func operator for calling Python functions from Relax IR (#18313)
4041f890ce is described below
commit 4041f890ce4db1b6547a5e5bcadfc7ee24e1ec8e
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Sep 19 09:28:05 2025 -0400
[Relax] Introduce R.call_py_func operator for calling Python functions from Relax IR (#18313)
This PR allows Python functions to be called directly from Relax IR,
enabling Relax computations to be combined with Python/PyTorch
operations.
### Usage Example
```python
@I.ir_module
class MyModule(BasePyModule):
    @I.pyfunc
    def pytorch_add(self, x, y):
        return x + y

    @R.function
    def compute(
        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
    ) -> R.Tensor((5,), "float32"):
        result = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32"))
        return result
```
---
python/tvm/relax/base_py_module.py | 11 +--
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/base.py | 36 ++++++++
python/tvm/script/ir_builder/relax/ir.py | 54 +++++++++++
src/relax/op/op.cc | 64 +++++++++++++
tests/python/relax/test_base_py_module_printer.py | 107 ++++++++++++++++++++++
6 files changed, 267 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/base_py_module.py
b/python/tvm/relax/base_py_module.py
index a4464cc737..52f813dc6b 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -234,12 +234,11 @@ class BasePyModule:
return out[0] if len(out) == 1 else out
def call_py_func(self, func_name: str, args):
-        """Call a Python function stored in the IRModule's pyfuncs."""
-        if func_name not in self.ir_mod.pyfuncs:
-            raise ValueError(f"Python function '{func_name}' not found in IRModule pyfuncs")
-        py_func = self.ir_mod.pyfuncs[func_name]
-        converted_args = self._convert_tvm_to_pytorch(args)
-        return py_func(*converted_args)
+        """Call a Python function stored in the module's pyfuncs.
+
+        Parameters
+        ----------
+        func_name : str
+            Key of the function in ``self.pyfuncs``.
+        args : sequence
+            Positional arguments, forwarded as-is after ``self``.
+
+        Returns
+        -------
+        The value returned by the Python function.
+
+        Raises
+        ------
+        ValueError
+            If ``func_name`` is not present in ``self.pyfuncs``.
+        """
+        # NOTE(review): unlike the previous version, args are no longer passed through
+        # self._convert_tvm_to_pytorch; callers must already supply values the Python
+        # function accepts (e.g. torch tensors) -- confirm this is intended.
+        if func_name not in self.pyfuncs:
+            raise ValueError(f"Python function '{func_name}' not found in module pyfuncs")
+        py_func = self.pyfuncs[func_name]
+        # pyfuncs are declared as `def f(self, ...)`, so the module instance is passed
+        # explicitly as the first argument.
+        return py_func(self, *args)
def _create_output_tensors(self, out_sinfo, in_args=None):
# pylint: disable=import-outside-toplevel
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index fd3672368b..6ea8305eca 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -27,6 +27,7 @@ from .base import (
call_dps_packed,
call_inplace_packed,
call_pure_packed,
+ call_py_func,
call_tir,
call_tir_inplace,
call_tir_with_grad,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index e77920d8de..e205abde30 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -304,6 +304,42 @@ def call_dps_packed(
return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore
+@args_converter.auto
+def call_py_func(
+    func_name: str,
+    args: Expr,
+    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
+) -> Call:
+    """
+    Call a Python function and return the output.
+
+    Parameters
+    ----------
+    func_name : StringImm
+        The name of the Python function to call, as a relax ``StringImm``. This should
+        correspond to a function in the IRModule's pyfuncs attribute. The value is passed
+        to the FFI constructor unchanged, and that constructor rejects a plain Python
+        ``str``; the TVMScript wrapper ``R.call_py_func`` performs the conversion.
+
+    args : Expr
+        The input arguments. Normalized with ``_wrap_inline_arg_tuple`` before the call.
+
+    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    ret: Call
+        A call node for the call_py_func operator.
+    """
+    # NOTE(review): the ``str`` annotation does not match what the FFI constructor accepts
+    # (a StringImm); it is left unchanged here to keep the signature stable.
+    args = _wrap_inline_arg_tuple(args)
+
+    # The FFI constructor takes a list of TensorStructInfo; promote a single entry.
+    if not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    return _ffi_api.call_py_func(func_name, args, out_sinfo)  # type: ignore
+
+
@args_converter.auto
def call_builtin_with_ctx(
func: Union[str, Expr],
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index d28ff3430a..3fa735197a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -30,6 +30,7 @@ from tvm.relax import (
Expr,
ExternFunc,
ShapeExpr,
+ StringImm,
TupleGetItem,
Var,
VarBinding,
@@ -64,6 +65,7 @@ from tvm.relax.op import (
call_dps_packed,
call_inplace_packed,
call_pure_packed,
+ call_py_func as _call_py_func,
call_tir,
call_tir_inplace,
call_tir_with_grad,
@@ -451,6 +453,57 @@ def call_packed(
return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)
+@args_converter.auto
+def call_py_func(
+    py_func_name: py_str,
+    *args: Expr,
+    out_sinfo: Union[StructInfo, List[StructInfo]],
+) -> Call:
+    """Create a relax Call, which calls a Python function.
+
+    Parameters
+    ----------
+    py_func_name: str
+        The name of the Python function to call. This should correspond to a function
+        in the IRModule's pyfuncs attribute. A ``StringImm`` is passed through unchanged.
+    *args : Expr
+        The arguments, given either positionally (``R.call_py_func("f", x, y, ...)``)
+        or as a single in-line tuple (``R.call_py_func("f", (x, y), ...)``). A single
+        in-line tuple is used directly as the argument tuple instead of being nested
+        inside another tuple.
+    out_sinfo: Union[StructInfo, List[StructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    call: Call
+        The created Relax Call for call_py_func operator.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm.relax import Tuple as _RelaxTuple
+
+    if isinstance(out_sinfo, py_tuple):  # type: ignore
+        out_sinfo = list(out_sinfo)
+    elif not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    # Materialize StructInfo proxies (callables / ObjectConvertible) into StructInfo.
+    out_sinfo = [
+        (
+            sinfo()
+            if callable(sinfo)
+            else sinfo.asobject()
+            if isinstance(sinfo, ObjectConvertible)
+            else sinfo
+        )
+        for sinfo in out_sinfo
+    ]
+
+    # args_converter has already turned an in-line Python tuple argument into a relax
+    # Tuple. Use it as the argument tuple directly; otherwise the op below would wrap it
+    # in a second Tuple and the Python function would receive one nested tuple.
+    if len(args) == 1 and isinstance(args[0], _RelaxTuple):
+        args = args[0]
+
+    # The underlying op requires a StringImm for the function name.
+    func_name_imm = (
+        StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name
+    )
+    return _call_py_func(func_name_imm, args, out_sinfo)
+
+
def _sinfo_arg_wrapper(func):
"""A wrapper to convert StructInfoProxies to StructInfo for builtin
operators with sinfo_args"""
@@ -743,6 +796,7 @@ __all__ = [
"call_tir_inplace",
"call_tir_with_grad",
"call_dps_packed",
+ "call_py_func",
"call_builtin_with_ctx",
"ceil",
"clip",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index e15d874723..d91c19b63f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -858,6 +858,70 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked);
}
+// call_py_func
+
+// Struct info of call_py_func: the single output struct info recorded in sinfo_args.
+StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) {
+  // MakeCallPyFunc always stores exactly one (possibly tuple) output struct info.
+  if (call->sinfo_args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "sinfo_args should have exact 1 output struct info.");
+  }
+  return call->sinfo_args[0];
+}
+
+// Validate a relax.call_py_func call: args[0] must be a StringImm naming the Python
+// function, and args[1] must be a tuple-typed in-line Tuple or a Var bound to one.
+void ValidateCallPyFunc(Call call) {
+  // Validate that the function name is a string literal
+  auto func_name = call->args[0];
+  CHECK(func_name->IsInstance<StringImmNode>())
+      << "Operation " << call->op << " expects the first argument to be a string literal "
+      << "specifying the Python function name. However, the first argument " << func_name
+      << " is not a string literal.";
+
+  // Validate that args is a tuple
+  Expr arg_tuple = call->args[1];
+  CHECK(arg_tuple->struct_info_.as<TupleStructInfoNode>())
+      << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. "
+      << "However, the second argument " << arg_tuple << " has struct info "
+      << arg_tuple->struct_info_ << ".";
+
+  // A Var is accepted because it may be normalized to an in-line tuple later.
+  CHECK(arg_tuple.as<TupleNode>() || arg_tuple.as<VarNode>())
+      << "Operation " << call->op << " must hold its arguments as an in-line tuple. "
+      << "However, " << call << " has arguments " << arg_tuple
+      << ", which is neither an in-line tuple, "
+      << "nor a variable binding that may be normalized to an in-line tuple.";
+}
+
+// relax.call_py_func takes two inputs: the function name and the argument tuple.
+// NOTE(review): FPurity is set to true although the op invokes arbitrary Python code;
+// confirm the referenced pyfuncs have no side effects that must be preserved.
+TVM_REGISTER_OP("relax.call_py_func")
+    .set_num_inputs(2)
+    .add_argument("func_name", "StringImm", "The name of the Python function to call.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallPyFunc)
+    .set_attr<FValidate>("FValidate", ValidateCallPyFunc)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+// Build a relax.call_py_func Call. Every output struct info must carry a concrete
+// ShapeExpr; a single output is recorded as-is, several outputs as a TupleStructInfo.
+Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array<TensorStructInfo> out_sinfo_list) {
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. "
+                               "However, one given structure info is "
+                            << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
+  }
+
+  static const Op& op = Op::Get("relax.call_py_func");
+  // The output struct info is stored in sinfo_args (read by InferStructInfoCallPyFunc).
+  return Call(op, {func_name, args}, {}, {out_sinfo});
+}
+
+// Expose the constructor to Python as relax.op.call_py_func.
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc);
+}
+
// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const
BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
diff --git a/tests/python/relax/test_base_py_module_printer.py
b/tests/python/relax/test_base_py_module_printer.py
index 92c799f6cb..6e87174fda 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -758,3 +758,110 @@ def test_python_functions_in_irmodule():
assert pyfuncs["multiply"].__name__ == "multiply"
else:
pytest.fail("pyfuncs attribute not found in IRModule")
+
+
+def test_call_py_func_validation():
+    """Test call_py_func validation and error handling."""
+    import torch
+
+    @I.ir_module
+    class ValidationTestModule(BasePyModule):
+        """Test module for validation."""
+
+        @I.pyfunc
+        def valid_func(self, x):
+            """Valid Python function."""
+            return x * 2
+
+        @R.function
+        def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+            # Refers to a pyfunc that does not exist. Building the module still succeeds;
+            # the missing name is only reported when it is looked up at call time.
+            result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
+            return result
+
+    device = tvm.cpu()
+    module = ValidationTestModule(device)
+
+    # Test that calling non-existent function raises error
+    x = torch.randn(5, dtype=torch.float32)
+
+    with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
+        module.call_py_func("non_existent_func", [x])
+
+
+def test_call_py_func_in_relax_function():
+    """Test using call_py_func within Relax functions."""
+    import torch
+
+    @I.ir_module
+    class RelaxCallPyFuncModule(BasePyModule):
+        """Test module with call_py_func in Relax functions."""
+
+        @I.pyfunc
+        def torch_relu(self, x):
+            """PyTorch ReLU implementation."""
+            return torch.relu(x)
+
+        @I.pyfunc
+        def torch_softmax(self, x, dim=0):
+            """PyTorch softmax implementation."""
+            return torch.softmax(x, dim=dim)
+
+        @R.function
+        def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
+            # Use Python function for ReLU
+            relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
+            # Use Python function for softmax
+            final_result = R.call_py_func(
+                "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
+            )
+            return final_result
+
+    device = tvm.cpu()
+    module = RelaxCallPyFuncModule(device)
+
+    # Test the mixed computation
+    x = torch.randn(10, dtype=torch.float32)
+
+    expected = torch.softmax(torch.relu(x), dim=0)
+
+    # mixed_computation is only built here, not executed; the same two steps are run
+    # directly through BasePyModule.call_py_func.
+    relu_result = module.call_py_func("torch_relu", [x])
+    final_result = module.call_py_func("torch_softmax", [relu_result])
+
+    assert torch.allclose(final_result, expected, atol=1e-5)
+
+
+def test_call_py_func_operator_creation():
+    """Test relax.op.call_py_func operator creation and basic properties."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax.expr import StringImm
+    from tvm.relax import Var, TensorStructInfo
+
+    # Create variables
+    x = Var("x", TensorStructInfo((5,), "float32"))
+    y = Var("y", TensorStructInfo((5,), "float32"))
+
+    # The low-level op takes the function name as a StringImm.
+    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
+
+    # Verify operator properties
+    assert call_expr.op.name == "relax.call_py_func"
+    assert call_expr.args[0].value == "test_func"
+    assert len(call_expr.args) == 2
+    # The argument tuple holds x and y directly (not a nested tuple).
+    assert len(call_expr.args[1].fields) == 2
+
+
+def test_call_py_func_compilation_validation():
+    """Test that relax.op.call_py_func rejects a plain str function name."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax import Var, TensorStructInfo
+
+    # The low-level op expects a StringImm; a raw Python str must be rejected by the FFI
+    # constructor. pytest.raises is used instead of try/assert False/except Exception,
+    # where the broad except also caught the AssertionError and -O stripped the assert.
+    with pytest.raises(Exception) as exc_info:
+        call_py_func(
+            "invalid",
+            (Var("x", TensorStructInfo((5,), "float32")),),
+            out_sinfo=R.Tensor((5,), "float32"),
+        )
+    message = str(exc_info.value)
+    assert "Mismatched type" in message or "Expected" in message