This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a54af64872 [Relax][Backend] Implement R.call_py_func operator for calling Python functions from compiled TVM (#18326)
a54af64872 is described below
commit a54af64872c68913309541f6f30e75da3921ef77
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Sep 21 23:36:18 2025 -0400
[Relax][Backend] Implement R.call_py_func operator for calling Python functions from compiled TVM (#18326)
This PR implements the `R.call_py_func` operator, which allows compiled TVM Relax modules to call Python functions at runtime. This enables integration between TVM's compiled code and Python through the Relax VM backend.
#### Simple Usage with BasePyModule
```python
@I.ir_module
class MyModule(BasePyModule):
    @I.pyfunc
    def torch_relu(self, x):
        return torch.relu(x)

    @R.function
    def forward(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
        return R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
```
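A rough usage sketch, not part of this patch: it assumes a CPU device and mirrors the `MyModule(device)` constructor and `call_py_func` helper exercised in the tests below (tensor conversion between TVM and PyTorch is handled by `BasePyModule`):

```python
import torch
import tvm

# Constructing the module compiles its Relax/TIR functions and registers the
# @I.pyfunc functions with the VM runtime (see _register_python_functions below).
device = tvm.cpu()
module = MyModule(device)

x = torch.randn(10, dtype=torch.float32)
# Invoke the registered Python function through the BasePyModule helper.
y = module.call_py_func("torch_relu", [x])
```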
#### Direct VM Backend Usage (Manual)
```python
# Manually register Python function with VM backend
register_func = tvm.get_global_func("vm.builtin.register_py_func")
register_func("my_func", my_python_function)
# Use in Relax function (compiled to VM backend)
@R.function
def test(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.call_py_func("my_func", (x,), out_sinfo=R.Tensor((5,),
"float32"))
# Manual cleanup (required for direct VM backend usage)
clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
clear_func()
```
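At build time, `LowerRuntimeBuiltin` rewrites `R.call_py_func(name, args)` into a call to the `vm.builtin.call_py_func` extern with a single `(name, args)` tuple, and the VM resolves the name against the registry populated above. A minimal sketch of exercising that registry directly from Python, with a purely illustrative doubling callback:

```python
import numpy as np
import tvm

# Illustrative callback; any callable taking and returning TVM tensors works here.
def my_python_function(x):
    return tvm.runtime.tensor(x.numpy() * 2.0)

register_func = tvm.get_global_func("vm.builtin.register_py_func")
get_func = tvm.get_global_func("vm.builtin.get_py_func")

register_func("my_func", my_python_function)
f = get_func("my_func")  # look up the registered packed function by name
out = f(tvm.runtime.tensor(np.ones((5,), dtype="float32")))
assert (out.numpy() == 2.0).all()

tvm.get_global_func("vm.builtin.clear_py_func_registry")()
```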
---
python/tvm/relax/base_py_module.py | 38 +++++++++
src/relax/backend/vm/codegen_vm.cc | 1 -
src/relax/backend/vm/lower_runtime_builtin.cc | 20 +++++
src/runtime/vm/builtin.cc | 74 +++++++++++++++++
tests/python/relax/test_base_py_module_printer.py | 96 ++++++++++-------------
tests/python/relax/test_relax_operators.py | 76 ++++++++++++++++++
6 files changed, 248 insertions(+), 57 deletions(-)
diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 52f813dc6b..7a790d28a7 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -45,6 +45,14 @@ class BasePyModule:
    Only IRModules that inherit from this class are allowed to contain Python functions.
"""
+    def __del__(self):
+        """Clean up registered Python functions on module destruction."""
+        try:
+            clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+            clear_func()
+        except (ValueError, AttributeError):
+            pass
+
def __init__(
self,
ir_mod: IRModule,
@@ -100,6 +108,7 @@ class BasePyModule:
self._compile_functions()
self._wrap_tir_functions()
self._wrap_relax_functions()
+ self._register_python_functions()
def _collect_function_names(self):
"""Collect names of TIR and Relax functions from IRModule."""
@@ -177,6 +186,35 @@ class BasePyModule:
setattr(self, func_name, _create_relax_wrapper(func_name))
+    def _register_python_functions(self):
+        """Register Python functions with the VM runtime for call_py_func support."""
+        if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
+            return
+
+        try:
+            register_py_func = tvm.get_global_func("vm.builtin.register_py_func")
+        except ValueError:
+            return
+
+        for func_name, py_func in self.ir_mod.pyfuncs.items():
+
+            def create_py_func_wrapper(name, original_func):
+                def wrapper(*args, **kwargs):
+                    converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args]
+                    converted_kwargs = {
+                        k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()
+                    }
+
+                    result = original_func(self, *converted_args, **converted_kwargs)
+
+                    return self._convert_pytorch_to_tvm(result)
+
+                wrapper.__name__ = name
+                return wrapper
+
+            wrapped_func = create_py_func_wrapper(func_name, py_func)
+            register_py_func(func_name, wrapped_func)
+
def call_tir(self, tir_func, args, out_sinfo):
"""Call a TIR function with PyTorch tensors."""
# Try to get function name from different sources
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 96dac05cb6..e2d9b5b068 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -368,7 +368,6 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
builder_->EmitCall(func, args, dst_reg);
}
-
void EmitNormalCall(const Call& call_node, RegName dst_reg) {
Instruction::Arg func = VisitExpr(call_node->op);
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
index d52155c615..71b8413e98 100644
--- a/src/relax/backend/vm/lower_runtime_builtin.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -24,6 +24,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
+#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
@@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
return ShapeOf(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call);
+ } else if (call->op == call_py_func_op_) {
+ return CallPyFunc(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
@@ -139,6 +142,21 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
    return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}
+  Expr CallPyFunc(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->struct_info_.defined());
+
+    // Create tuple with function name and arguments tuple
+    ffi::Array<Expr> tuple_fields;
+    tuple_fields.push_back(call_node->args[0]);  // function name
+    tuple_fields.push_back(call_node->args[1]);  // arguments tuple
+    auto combined_tuple = Tuple(tuple_fields);
+
+    // Direct call to vm.builtin.call_py_func
+    return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args,
+                call_node->span);
+  }
+
Expr ToDevice(const Call& call_node) {
// TODO(yongwww): replace ToVDeviceAttrs with related Expr
ICHECK(call_node->args.size() == 1);
@@ -198,6 +216,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+ const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -216,6 +235,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
+ const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 362a7e4c89..41c011678e 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -34,6 +34,8 @@
#include <tvm/runtime/vm/bytecode.h>
#include <tvm/runtime/vm/vm.h>
+#include <unordered_map>
+
namespace tvm {
namespace runtime {
namespace vm {
@@ -430,6 +432,78 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}
+//-------------------------------------
+// Python function call support
+//-------------------------------------
+
+// Global registry for Python functions
+static std::unordered_map<std::string, ffi::Function> py_func_registry;
+
+/*!
+ * \brief Clear the Python function registry on shutdown
+ */
+void ClearPyFuncRegistry() { py_func_registry.clear(); }
+
+/*!
+ * \brief Register a Python function for call_py_func
+ * \param name The function name
+ * \param func The Python function wrapped as ffi::Function
+ */
+void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; }
+
+/*!
+ * \brief Get a registered Python function
+ * \param name The function name
+ * \return The Python function
+ */
+ffi::Function GetPyFunc(const std::string& name) {
+ auto it = py_func_registry.find(name);
+ if (it == py_func_registry.end()) {
+ LOG(FATAL) << "Python function '" << name << "' not found in registry";
+ }
+ return it->second;
+}
+
+/*!
+ * \brief Call a Python function from VM
+ * \param args The packed function arguments (tuple containing function name and arguments)
+ * \param rv The return value
+ */
+void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) {
+ // args[0] should be a tuple containing (func_name, args_tuple)
+ if (args.size() != 1) {
+ LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)";
+ }
+
+ auto tuple_arg = args[0].cast<ffi::Array<ffi::Any>>();
+ if (tuple_arg.size() != 2) {
+ LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name,
args)";
+ }
+
+ // Get function name
+ std::string func_name = tuple_arg[0].cast<ffi::String>();
+
+ // Get arguments tuple
+ auto func_args = tuple_arg[1].cast<ffi::Array<ffi::Any>>();
+
+ // Look up Python function in registry
+ ffi::Function py_func = GetPyFunc(func_name);
+
+ // Call the Python function with the arguments
+ std::vector<ffi::AnyView> py_args_vec(func_args.begin(), func_args.end());
+ ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size());
+ py_func.CallPacked(py_args, rv);
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef()
+ .def_packed("vm.builtin.call_py_func", CallPyFunc)
+ .def("vm.builtin.register_py_func", RegisterPyFunc)
+ .def("vm.builtin.get_py_func", GetPyFunc)
+ .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry);
+}
+
//-------------------------------------
// Builtin runtime operators.
//-------------------------------------
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index 6e87174fda..c9d23a7465 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -760,43 +760,54 @@ def test_python_functions_in_irmodule():
pytest.fail("pyfuncs attribute not found in IRModule")
-def test_call_py_func_validation():
- """Test call_py_func validation and error handling."""
+def test_call_py_func_with_base_py_module():
+ """Test R.call_py_func with BasePyModule."""
import torch
+ import numpy as np
+ from tvm.relax.op import call_py_func
+ from tvm.relax.expr import StringImm
+ from tvm.relax import Var, TensorStructInfo
- @I.ir_module
- class ValidationTestModule(BasePyModule):
- """Test module for validation."""
+ # Test 1: Operator creation and basic properties
+ x = Var("x", TensorStructInfo((5,), "float32"))
+ y = Var("y", TensorStructInfo((5,), "float32"))
- @I.pyfunc
- def valid_func(self, x):
- """Valid Python function."""
- return x * 2
+    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
+ assert call_expr.op.name == "relax.call_py_func"
+ assert call_expr.args[0].value == "test_func"
+ assert len(call_expr.args) == 2
+
+ # Test 2: Compilation validation
+ try:
+ call_py_func(
+ "invalid",
+ (Var("x", TensorStructInfo((5,), "float32")),),
+ out_sinfo=R.Tensor((5,), "float32"),
+ )
+ assert False, "Should raise type error"
+ except Exception as e:
+ assert "Mismatched type" in str(e) or "Expected" in str(e)
+
+ # Test 3: Validation and error handling
+ @I.ir_module
+ class ValidationTestModule(BasePyModule):
@R.function
def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
- # This should cause a validation error
result = R.call_py_func("non_existent_func", (x,),
out_sinfo=R.Tensor((5,), "float32"))
return result
device = tvm.cpu()
module = ValidationTestModule(device)
- # Test that calling non-existent function raises error
x = torch.randn(5, dtype=torch.float32)
    with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
module.call_py_func("non_existent_func", [x])
-
-def test_call_py_func_in_relax_function():
- """Test using call_py_func within Relax functions."""
- import torch
-
+ # Test 4: Using call_py_func within Relax functions
@I.ir_module
class RelaxCallPyFuncModule(BasePyModule):
- """Test module with call_py_func in Relax functions."""
-
@I.pyfunc
def torch_relu(self, x):
"""PyTorch ReLU implementation."""
@@ -809,9 +820,7 @@ def test_call_py_func_in_relax_function():
@R.function
def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
- # Use Python function for ReLU
    relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
- # Use Python function for softmax
final_result = R.call_py_func(
"torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,),
"float32")
)
@@ -820,7 +829,6 @@ def test_call_py_func_in_relax_function():
device = tvm.cpu()
module = RelaxCallPyFuncModule(device)
- # Test the mixed computation
x = torch.randn(10, dtype=torch.float32)
expected = torch.softmax(torch.relu(x), dim=0)
@@ -828,40 +836,16 @@ def test_call_py_func_in_relax_function():
relu_result = module.call_py_func("torch_relu", [x])
final_result = module.call_py_func("torch_softmax", [relu_result])
- assert torch.allclose(final_result, expected, atol=1e-5)
-
-
-def test_call_py_func_operator_creation():
- """Test R.call_py_func operator creation and basic properties."""
- from tvm.relax.op import call_py_func
- from tvm.relax.expr import StringImm
- from tvm.relax import Var, TensorStructInfo
-
- # Create variables
- x = Var("x", TensorStructInfo((5,), "float32"))
- y = Var("y", TensorStructInfo((5,), "float32"))
-
- # Create call_py_func call
-    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
-
- # Verify operator properties
- assert call_expr.op.name == "relax.call_py_func"
- assert call_expr.args[0].value == "test_func"
- assert len(call_expr.args) == 2
-
+ # Convert to numpy for comparison
+ if isinstance(final_result, tvm.runtime.Tensor):
+ final_result_np = final_result.numpy()
+ else:
+ final_result_np = final_result
-def test_call_py_func_compilation_validation():
- """Test call_py_func compilation validation."""
- from tvm.relax.op import call_py_func
- from tvm.relax import Var, TensorStructInfo
+ if isinstance(expected, torch.Tensor):
+ expected_np = expected.numpy()
+ else:
+ expected_np = expected
- # Test operator parameter validation
- try:
- call_py_func(
- "invalid",
- (Var("x", TensorStructInfo((5,), "float32")),),
- out_sinfo=R.Tensor((5,), "float32"),
- )
- assert False, "Should raise type error"
- except Exception as e:
- assert "Mismatched type" in str(e) or "Expected" in str(e)
+ # Use numpy for comparison since we have numpy arrays
+    np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 8558f6e911..897082dd79 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -409,6 +409,82 @@ def test_op_call_inplace_packed(exec_mode):
assert (result[1].numpy() == sum).all()
+def test_op_call_py_func(exec_mode):
+ """Test R.call_py_func operator functionality."""
+ import torch
+
+ def torch_relu(x):
+ if isinstance(x, tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(x.numpy())
+ elif hasattr(x, "asnumpy"):
+ x_torch = torch.from_numpy(x.asnumpy())
+ else:
+ x_np = np.array(x)
+ if isinstance(x_np, tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(x_np.numpy())
+ elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+ if x_torch.ndim > 1:
+ x_torch = x_torch.flatten()
+ else:
+ x_torch = torch.from_numpy(x_np)
+ result = torch.relu(x_torch)
+ return tvm.runtime.tensor(result.numpy())
+
+ def torch_sigmoid(x):
+ if isinstance(x, tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(x.numpy())
+ elif hasattr(x, "asnumpy"):
+ x_torch = torch.from_numpy(x.asnumpy())
+ else:
+ x_np = np.array(x)
+ if isinstance(x_np, tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(x_np.numpy())
+ elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+ x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+ if x_torch.ndim > 1:
+ x_torch = x_torch.flatten()
+ else:
+ x_torch = torch.from_numpy(x_np)
+ result = torch.sigmoid(x_torch)
+ return tvm.runtime.tensor(result.numpy())
+
+ register_func = tvm.get_global_func("vm.builtin.register_py_func")
+ register_func("torch_relu", torch_relu)
+ register_func("torch_sigmoid", torch_sigmoid)
+
+ @tvm.script.ir_module
+ class CallPyFuncTest:
+ @R.function
+ def simple_call(x: R.Tensor((3,), "float32")):
+        result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32"))
+ return result
+
+ @R.function
+ def multiple_calls(x: R.Tensor((2,), "float32")):
+        y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32"))
+        z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), "float32"))
+ return z
+
+ np.random.seed(0)
+ x_data = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+ x_tvm = tvm.runtime.tensor(x_data)
+
+ result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode)
+ expected = np.maximum(x_data, 0.0)
+ assert (result.numpy() == expected).all()
+
+ y_data = np.array([-0.5, 0.5], dtype=np.float32)
+ y_tvm = tvm.runtime.tensor(y_data)
+
+    result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode)
+ expected2 = 1.0 / (1.0 + np.exp(-np.maximum(y_data, 0.0)))
+ assert (result2.numpy() == expected2).all()
+
+ clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+ clear_func()
+
+
def test_op_to_device(exec_mode):
@tvm.script.ir_module
class CallToDevice: