This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4041f890ce [Relax] Introduce R.call_py_func operator for calling Python functions from Relax IR (#18313)
4041f890ce is described below

commit 4041f890ce4db1b6547a5e5bcadfc7ee24e1ec8e
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Sep 19 09:28:05 2025 -0400

    [Relax] Introduce R.call_py_func operator for calling Python functions from Relax IR (#18313)
    
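    This PR allows calling Python functions directly from Relax IR,
    enabling integration between Relax computations and Python/PyTorch
    operations. In the example below, `R.call_py_func` receives the name
    of the `@I.pyfunc` method, a tuple of Relax arguments, and the struct
    info of the result.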
    
    ### Usage Example
    ```python
    @I.ir_module
    class MyModule(BasePyModule):
        @I.pyfunc
        def pytorch_add(self, x, y):
            return x + y
    
        @R.function
        def compute(x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
            result = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32"))
            return result
    ```
---
 python/tvm/relax/base_py_module.py                |  11 +--
 python/tvm/relax/op/__init__.py                   |   1 +
 python/tvm/relax/op/base.py                       |  36 ++++++++
 python/tvm/script/ir_builder/relax/ir.py          |  54 +++++++++++
 src/relax/op/op.cc                                |  64 +++++++++++++
 tests/python/relax/test_base_py_module_printer.py | 107 ++++++++++++++++++++++
 6 files changed, 267 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index a4464cc737..52f813dc6b 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -234,12 +234,11 @@ class BasePyModule:
         return out[0] if len(out) == 1 else out
 
     def call_py_func(self, func_name: str, args):
-        """Call a Python function stored in the IRModule's pyfuncs."""
-        if func_name not in self.ir_mod.pyfuncs:
-            raise ValueError(f"Python function '{func_name}' not found in 
IRModule pyfuncs")
-        py_func = self.ir_mod.pyfuncs[func_name]
-        converted_args = self._convert_tvm_to_pytorch(args)
-        return py_func(*converted_args)
+        """Call a Python function stored in the module's pyfuncs."""
+        if func_name not in self.pyfuncs:
+            raise ValueError(f"Python function '{func_name}' not found in 
module pyfuncs")
+        py_func = self.pyfuncs[func_name]
+        return py_func(self, *args)
 
     def _create_output_tensors(self, out_sinfo, in_args=None):
         # pylint: disable=import-outside-toplevel
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index fd3672368b..6ea8305eca 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -27,6 +27,7 @@ from .base import (
     call_dps_packed,
     call_inplace_packed,
     call_pure_packed,
+    call_py_func,
     call_tir,
     call_tir_inplace,
     call_tir_with_grad,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index e77920d8de..e205abde30 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -304,6 +304,42 @@ def call_dps_packed(
     return _ffi_api.call_dps_packed(func, args, out_sinfo)  # type: ignore
 
 
+@args_converter.auto
+def call_py_func(
+    func_name: str,
+    args: Expr,
+    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
+) -> Call:
+    """
+    Call a Python function and return the output.
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the Python function to call. This should correspond to a 
function
+        in the IRModule's pyfuncs attribute.
+
+    args : Expr
+        The input arguments.
+
+    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes 
the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    ret: Call
+        A call node for the call_py_func operator.
+    """
+    args = _wrap_inline_arg_tuple(args)
+
+    if not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    return _ffi_api.call_py_func(func_name, args, out_sinfo)  # type: ignore
+
+
 @args_converter.auto
 def call_builtin_with_ctx(
     func: Union[str, Expr],
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index d28ff3430a..3fa735197a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -30,6 +30,7 @@ from tvm.relax import (
     Expr,
     ExternFunc,
     ShapeExpr,
+    StringImm,
     TupleGetItem,
     Var,
     VarBinding,
@@ -64,6 +65,7 @@ from tvm.relax.op import (
     call_dps_packed,
     call_inplace_packed,
     call_pure_packed,
+    call_py_func as _call_py_func,
     call_tir,
     call_tir_inplace,
     call_tir_with_grad,
@@ -451,6 +453,57 @@ def call_packed(
     return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)
 
 
+@args_converter.auto
+def call_py_func(
+    py_func_name: py_str,
+    *args: Expr,
+    out_sinfo: Union[StructInfo, List[StructInfo]],
+) -> Call:
+    """Create a relax Call, which calls a Python function.
+
+    Parameters
+    ----------
+    py_func_name: str
+        The name of the Python function to call. This should correspond to a 
function
+        in the IRModule's pyfuncs attribute.
+    *args : Expr
+        The arguments.
+    out_sinfo: Union[StructInfo, List[StructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes 
the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    call: Call
+        The created Relax Call for call_py_func operator.
+    """
+    if isinstance(out_sinfo, py_tuple):  # type: ignore
+        out_sinfo = list(out_sinfo)
+    elif not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    out_sinfo = [
+        (
+            sinfo()
+            if callable(sinfo)
+            else sinfo.asobject()
+            if isinstance(sinfo, ObjectConvertible)
+            else sinfo
+        )
+        for sinfo in out_sinfo
+    ]
+
+    # Convert string to StringImm
+    try:
+        func_name_imm = (
+            StringImm(py_func_name) if isinstance(py_func_name, py_str) else 
py_func_name
+        )
+    except (TypeError, ValueError, AttributeError):
+        func_name_imm = StringImm(py_func_name)
+    return _call_py_func(func_name_imm, args, out_sinfo)
+
+
 def _sinfo_arg_wrapper(func):
     """A wrapper to convert StructInfoProxies to StructInfo for builtin 
operators with sinfo_args"""
 
@@ -743,6 +796,7 @@ __all__ = [
     "call_tir_inplace",
     "call_tir_with_grad",
     "call_dps_packed",
+    "call_py_func",
     "call_builtin_with_ctx",
     "ceil",
     "clip",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index e15d874723..d91c19b63f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -858,6 +858,70 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked);
 }
 
+// call_py_func
+
+StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& 
ctx) {
+  if (call->sinfo_args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "sinfo_args should have exact 1 output struct info.");
+  }
+  return call->sinfo_args[0];
+}
+
+void ValidateCallPyFunc(Call call) {
+  // Validate that the function name is a string literal
+  auto func_name = call->args[0];
+  CHECK(func_name->IsInstance<StringImmNode>())
+      << "Operation " << call->op << " expects the first argument to be a 
string literal "
+      << "specifying the Python function name. However, the first argument " 
<< func_name
+      << " is not a string literal.";
+
+  // Validate that args is a tuple
+  Expr arg_tuple = call->args[1];
+  CHECK(arg_tuple->struct_info_.as<TupleStructInfoNode>())
+      << "Operation " << call->op << " expects the second argument to be a 
tuple of relax Expr.  "
+      << "However, the second argument " << arg_tuple << " has struct info "
+      << arg_tuple->struct_info_ << ".";
+
+  CHECK(arg_tuple.as<TupleNode>() || arg_tuple.as<VarNode>())
+      << "Operation " << call->op << " must hold its arguments as an in-line 
tuple.  "
+      << "However, " << call << " has arguments " << arg_tuple
+      << ", which is neither an in-line tuple, "
+      << "nor a variable binding that may be normalized to an in-line tuple.";
+}
+
+TVM_REGISTER_OP("relax.call_py_func")
+    .set_num_inputs(2)
+    .add_argument("func_name", "StringImm", "The name of the Python function 
to call.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallPyFunc)
+    .set_attr<FValidate>("FValidate", ValidateCallPyFunc)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeCallPyFunc(StringImm func_name, Tuple args, 
ffi::Array<TensorStructInfo> out_sinfo_list) {
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined 
ShapeExpr as shape. "
+                               "However, one given structure info is "
+                            << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), 
out_sinfo_list.end()});
+  }
+
+  static const Op& op = Op::Get("relax.call_py_func");
+  return Call(op, {func_name, args}, {}, {out_sinfo});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc);
+}
+
 // call builtin
 StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const 
BlockBuilder& ctx) {
   if (call->sinfo_args.size() == 0) {
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index 92c799f6cb..6e87174fda 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -758,3 +758,110 @@ def test_python_functions_in_irmodule():
         assert pyfuncs["multiply"].__name__ == "multiply"
     else:
         pytest.fail("pyfuncs attribute not found in IRModule")
+
+
+def test_call_py_func_validation():
+    """Test call_py_func validation and error handling."""
+    import torch
+
+    @I.ir_module
+    class ValidationTestModule(BasePyModule):
+        """Test module for validation."""
+
+        @I.pyfunc
+        def valid_func(self, x):
+            """Valid Python function."""
+            return x * 2
+
+        @R.function
+        def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), 
"float32"):
+            # This should cause a validation error
+            result = R.call_py_func("non_existent_func", (x,), 
out_sinfo=R.Tensor((5,), "float32"))
+            return result
+
+    device = tvm.cpu()
+    module = ValidationTestModule(device)
+
+    # Test that calling non-existent function raises error
+    x = torch.randn(5, dtype=torch.float32)
+
+    with pytest.raises(ValueError, match="Python function 'non_existent_func' 
not found"):
+        module.call_py_func("non_existent_func", [x])
+
+
+def test_call_py_func_in_relax_function():
+    """Test using call_py_func within Relax functions."""
+    import torch
+
+    @I.ir_module
+    class RelaxCallPyFuncModule(BasePyModule):
+        """Test module with call_py_func in Relax functions."""
+
+        @I.pyfunc
+        def torch_relu(self, x):
+            """PyTorch ReLU implementation."""
+            return torch.relu(x)
+
+        @I.pyfunc
+        def torch_softmax(self, x, dim=0):
+            """PyTorch softmax implementation."""
+            return torch.softmax(x, dim=dim)
+
+        @R.function
+        def mixed_computation(x: R.Tensor((10,), "float32")) -> 
R.Tensor((10,), "float32"):
+            # Use Python function for ReLU
+            relu_result = R.call_py_func("torch_relu", (x,), 
out_sinfo=R.Tensor((10,), "float32"))
+            # Use Python function for softmax
+            final_result = R.call_py_func(
+                "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), 
"float32")
+            )
+            return final_result
+
+    device = tvm.cpu()
+    module = RelaxCallPyFuncModule(device)
+
+    # Test the mixed computation
+    x = torch.randn(10, dtype=torch.float32)
+
+    expected = torch.softmax(torch.relu(x), dim=0)
+
+    relu_result = module.call_py_func("torch_relu", [x])
+    final_result = module.call_py_func("torch_softmax", [relu_result])
+
+    assert torch.allclose(final_result, expected, atol=1e-5)
+
+
+def test_call_py_func_operator_creation():
+    """Test R.call_py_func operator creation and basic properties."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax.expr import StringImm
+    from tvm.relax import Var, TensorStructInfo
+
+    # Create variables
+    x = Var("x", TensorStructInfo((5,), "float32"))
+    y = Var("y", TensorStructInfo((5,), "float32"))
+
+    # Create call_py_func call
+    call_expr = call_py_func(StringImm("test_func"), (x, y), 
out_sinfo=R.Tensor((5,), "float32"))
+
+    # Verify operator properties
+    assert call_expr.op.name == "relax.call_py_func"
+    assert call_expr.args[0].value == "test_func"
+    assert len(call_expr.args) == 2
+
+
+def test_call_py_func_compilation_validation():
+    """Test call_py_func compilation validation."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax import Var, TensorStructInfo
+
+    # Test operator parameter validation
+    try:
+        call_py_func(
+            "invalid",
+            (Var("x", TensorStructInfo((5,), "float32")),),
+            out_sinfo=R.Tensor((5,), "float32"),
+        )
+        assert False, "Should raise type error"
+    except Exception as e:
+        assert "Mismatched type" in str(e) or "Expected" in str(e)

Reply via email to