This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e229bda76f [Docs] Add tutorial for mixing Python/PyTorch with TVM
using BasePyModule (#18947)
e229bda76f is described below
commit e229bda76faff035a19fcdc515be51059ed4957b
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 29 15:26:49 2026 -0400
[Docs] Add tutorial for mixing Python/PyTorch with TVM using BasePyModule
(#18947)
This pr add a new tutorial `mix_python_and_tvm_with_pymodule.py`
demonstrating how to use `BasePyModule` to mix Python/PyTorch functions
with TIR and Relax in a single IRModule.
## Tutorial Contents (7 steps)
- **Step 1**: `I.pyfunc` + `call_tir` basics, DLPack zero-copy
conversion, `show()`
- **Step 2**: Debugging with `print` in pyfuncs — inspect intermediate
tensors without compiling
- **Step 3**: Realistic pipeline combining `call_tir`,
`call_dps_packed`, and Python/PyTorch in one forward pass
- **Step 4**: Dynamic function registration via `add_python_function`
- **Step 5**: `RelaxToPyFuncConverter` — convert Relax IR to PyTorch at
different compilation stages (before and after passes) to verify
numerical correctness
- **Step 6**: `R.call_py_func` — cross-level calls between compiled
Relax VM and Python functions
- **Step 7**: Symbolic shapes for dynamic batch sizes
This pr also fixs a bug in `BasePyModule._compile_functions` where
modules without Relax functions would incorrectly attempt Relax VM
compilation, producing spurious warnings like `Failed to compile Relax
VM: 'NoneType' object has no attribute 'kind'`.
---
.../tutorials/mix_python_and_tvm_with_pymodule.py | 468 +++++++++++++++++++++
python/tvm/relax/base_py_module.py | 9 +-
2 files changed, 469 insertions(+), 8 deletions(-)
diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
new file mode 100644
index 0000000000..91d1cb9c26
--- /dev/null
+++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
@@ -0,0 +1,468 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402
+
+"""
+.. _mix_python_and_tvm:
+
+Mix Python/PyTorch with TVM Using BasePyModule
+===============================================
+In a typical TVM workflow, you write an ``IRModule``, compile it, and load the
compiled artifact
+into a ``VirtualMachine`` to run. This means **you cannot test or debug
anything until the entire
+module compiles successfully**. If a single op is unsupported, the whole
pipeline is blocked.
+
+``BasePyModule`` solves this by letting Python functions, TIR kernels, and
Relax functions coexist
+in one module. TIR and Relax functions are JIT-compiled on instantiation,
Python functions run
+as-is, and tensors move between TVM and PyTorch via zero-copy DLPack. This
enables:
+
+- **Incremental development**: get a model running with Python fallbacks
first, then replace them
+ with TVM ops one by one.
+- **Easy debugging**: insert ``print`` in Python functions to inspect
intermediate tensors — no
+ need to compile the whole module first.
+- **Verification at any compilation stage**: convert Relax IR back to PyTorch
to check numerical
+ correctness before and after optimization passes.
+- **Hybrid execution**: let the compiled VM call back into Python for ops that
are hard to
+ express in TIR or Relax.
+
+This tutorial walks through the full workflow step by step.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 1
+"""
+
+######################################################################
+# Preparation
+# -----------
+
+import os
+
+try:
+ import torch
+ import torch.nn.functional as F
+except ImportError:
+ torch = None
+
+import tvm
+from tvm import relax
+from tvm.relax.base_py_module import BasePyModule
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tirx as T
+
+IS_IN_CI = os.getenv("CI", "").lower() == "true"
+HAS_TORCH = torch is not None
+RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI
+
+
+######################################################################
+# Step 1: Your First Hybrid Module
+# ----------------------------------
+# The core idea: decorate a class with ``@I.ir_module``, inherit from
``BasePyModule``, and use
+# three decorators for three kinds of functions:
+#
+# - ``@T.prim_func`` — low-level TIR kernel (JIT-compiled on instantiation)
+# - ``@R.function`` — high-level Relax graph (JIT-compiled on instantiation)
+# - ``@I.pyfunc`` — plain Python (runs as-is, can use any Python library)
+#
+# ``call_tir`` bridges Python and TIR: it converts PyTorch tensors to TVM
NDArrays via DLPack
+# (zero-copy), allocates the output buffer, calls the compiled kernel, and
converts back.
+
+if RUN_EXAMPLE:
+
+ @I.ir_module
+ class MyFirstModule(BasePyModule):
+ @T.prim_func
+ def add_tir(
+ A: T.Buffer((4,), "float32"),
+ B: T.Buffer((4,), "float32"),
+ C: T.Buffer((4,), "float32"),
+ ):
+ for i in range(4):
+ C[i] = A[i] + B[i]
+
+ @I.pyfunc
+ def forward(self, x, y):
+ """Takes PyTorch tensors, calls TIR, returns PyTorch tensors."""
+ x_tvm = self._convert_pytorch_to_tvm(x)
+ y_tvm = self._convert_pytorch_to_tvm(y)
+ result = self.call_tir(
+ self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,),
"float32")
+ )
+ return self._convert_tvm_to_pytorch(result)
+
+ # TIR functions are JIT-compiled at instantiation
+ mod = MyFirstModule(device=tvm.cpu(0))
+
+ x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+ y = torch.tensor([10.0, 20.0, 30.0, 40.0])
+ result = mod.forward(x, y)
+
+ print("forward(x, y) =", result)
+ assert torch.allclose(result, x + y)
+
+ # show() prints TVMScript including Python functions (shown as ExternFunc)
+ mod.show()
+
+ # list_functions() shows what is available in the module
+ print("Available functions:", mod.list_functions())
+
+
+######################################################################
+# Step 2: Debugging — The Main Selling Point
+# ---------------------------------------------
+# Traditional ML compilers treat computation graphs as monolithic blobs. You
cannot inspect
+# intermediate tensor values without compiling the entire module. With
``@I.pyfunc``, debugging
+# is as simple as adding a ``print`` statement. You can also make quick edits
and re-run
+# immediately — no recompilation needed.
+
+if RUN_EXAMPLE:
+
+ @I.ir_module
+ class DebugModule(BasePyModule):
+ @T.prim_func
+ def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+ n = T.int32()
+ A = T.match_buffer(var_A, (n, 4), "float32")
+ B = T.match_buffer(var_B, (4, 3), "float32")
+ C = T.match_buffer(var_C, (n, 3), "float32")
+ for i, j, k in T.grid(n, 3, 4):
+ with T.sblock("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @I.pyfunc
+ def forward(self, x, weights):
+ # Inspect input
+ print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")
+
+ # Run TIR matmul
+ x_tvm = self._convert_pytorch_to_tvm(x)
+ w_tvm = self._convert_pytorch_to_tvm(weights)
+ out = self.call_tir(
+ self.matmul_tir,
+ [x_tvm, w_tvm],
+ out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
+ )
+ logits = self._convert_tvm_to_pytorch(out)
+
+ # Inspect intermediate value — impossible with a compiled-only
workflow
+ print(f" [DEBUG] logits shape: {logits.shape}, "
+ f"min: {logits.min():.4f}, max: {logits.max():.4f}")
+
+ result = F.softmax(logits, dim=-1)
+
+ # Verify output
+ print(f" [DEBUG] probs sum: {result.sum(dim=-1)}")
+ return result
+
+ mod = DebugModule(device=tvm.cpu(0))
+
+ print("Running with debug prints:")
+ probs = mod.forward(torch.randn(2, 4), torch.randn(4, 3))
+ assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5)
+
+######################################################################
+# This is the key benefit: "debugging is as simple as inserting a print
statement.
+# Users can also make quick, manual edits to Python functions and immediately
observe the
+# results." No compilation cycle, no VM loading — just Python.
+
+
+######################################################################
+# Step 3: A Realistic Pipeline — Python, TIR, and Packed Functions
+# -------------------------------------------------------------------
+# Real models combine many kinds of operations. This step builds a mini
inference pipeline using
+# three different calling conventions:
+#
+# - ``call_tir``: call a compiled TIR kernel
+# - ``call_dps_packed``: call a TVM packed function (e.g., a third-party
library binding)
+# - Direct Python: call any PyTorch function
+#
+# ``call_dps_packed`` is useful for calling functions registered via
``tvm.register_global_func``
+# — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions.
+
+if RUN_EXAMPLE:
+
+ # Register a packed function (simulating an external library binding)
+ @tvm.register_global_func("my_bias_add", override=True)
+ def my_bias_add(x, bias, out):
+ """Packed function: adds bias to each row of x."""
+ import numpy as np
+
+ x_np = x.numpy()
+ b_np = bias.numpy()
+ out_np = x_np + b_np
+ out[:] = out_np
+
+ @I.ir_module
+ class PipelineModule(BasePyModule):
+ @T.prim_func
+ def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+ A = T.match_buffer(var_A, (2, 4), "float32")
+ B = T.match_buffer(var_B, (4, 3), "float32")
+ C = T.match_buffer(var_C, (2, 3), "float32")
+ for i, j, k in T.grid(2, 3, 4):
+ with T.sblock("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @I.pyfunc
+ def forward(self, x, weights, bias):
+ # 1. TIR matmul
+ x_tvm = self._convert_pytorch_to_tvm(x)
+ w_tvm = self._convert_pytorch_to_tvm(weights)
+ h = self.call_tir(
+ self.matmul_tir, [x_tvm, w_tvm],
+ out_sinfo=R.Tensor((2, 3), "float32"),
+ )
+ h_pt = self._convert_tvm_to_pytorch(h)
+
+ # 2. Packed function for bias add (simulating an external library)
+ h_biased = self.call_dps_packed(
+ "my_bias_add", [h_pt, bias],
+ out_sinfo=R.Tensor((2, 3), "float32"),
+ )
+
+ # 3. Python/PyTorch activation
+ return F.relu(h_biased)
+
+ mod = PipelineModule(device=tvm.cpu(0))
+
+ x = torch.randn(2, 4)
+ w = torch.randn(4, 3)
+ b = torch.randn(3)
+ result = mod.forward(x, w, b)
+
+ expected = F.relu(x @ w + b)
+ print("Pipeline result:", result)
+ print("Expected: ", expected)
+ assert torch.allclose(result, expected, atol=1e-4)
+
+
+######################################################################
+# Step 4: Relax-to-Python Converter — Verify at Any Compilation Stage
+# ----------------------------------------------------------------------
+# Both Relax functions and Python functions describe computational graphs. The
+# ``RelaxToPyFuncConverter`` converts Relax IR into equivalent PyTorch code by
mapping
+# Relax operators to their PyTorch counterparts (e.g., ``R.nn.relu`` →
``F.relu``).
+#
+# A key feature: **this conversion can happen at any stage of compilation**.
+# You can convert early (right after import) or late (after optimization
passes have
+# transformed the IR), and compare the output against a PyTorch reference to
catch bugs.
+
+if RUN_EXAMPLE:
+ from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
+
+ # A simple Relax module: matmul + bias + relu (a dense layer)
+ @I.ir_module
+ class DenseLayer:
+ @T.prim_func
+ def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle):
+ x = T.match_buffer(var_x, (2, 4), "float32")
+ b = T.match_buffer(var_b, (4,), "float32")
+ out = T.match_buffer(var_out, (2, 4), "float32")
+ for i, j in T.grid(2, 4):
+ out[i, j] = x[i, j] + b[j]
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 4), "float32"),
+ w: R.Tensor((4, 4), "float32"),
+ b: R.Tensor((4,), "float32"),
+ ) -> R.Tensor((2, 4), "float32"):
+ h = R.matmul(x, w)
+ cls = DenseLayer
+ h_bias = R.call_tir(
+ cls.bias_add_tir, (h, b),
+ out_sinfo=R.Tensor((2, 4), "float32"),
+ )
+ return R.nn.relu(h_bias)
+
+ # --- Stage 1: Convert BEFORE optimization ---
+ converter = RelaxToPyFuncConverter(DenseLayer)
+ converted_early = converter.convert(["main"])
+
+ x = torch.randn(2, 4)
+ w = torch.randn(4, 4)
+ b = torch.randn(4)
+
+ py_result_early = converted_early.pyfuncs["main"](x, w, b)
+ expected = F.relu(x @ w + b)
+
+ print("Before optimization:")
+ print(" Converted result:", py_result_early)
+ print(" PyTorch expected:", expected)
+ assert torch.allclose(py_result_early, expected, atol=1e-5)
+
+ # --- Stage 2: Apply a pass, then convert AFTER optimization ---
+ # Run CanonicalizeBindings to clean up the IR, then convert again
+ # to verify the pass did not break numerical correctness.
+ optimized_mod = relax.transform.CanonicalizeBindings()(DenseLayer)
+
+ converter_late = RelaxToPyFuncConverter(optimized_mod)
+ converted_late = converter_late.convert(["main"])
+
+ py_result_late = converted_late.pyfuncs["main"](x, w, b)
+
+ print("\nAfter CanonicalizeBindings pass:")
+ print(" Converted result:", py_result_late)
+ print(" Still matches: ",
+ torch.allclose(py_result_late, expected, atol=1e-5))
+ assert torch.allclose(py_result_late, expected, atol=1e-5)
+
+
+######################################################################
+# Step 5: R.call_py_func — Python Callbacks in Compiled IR
+# -----------------------------------------------------------
+# ``R.call_py_func`` embeds a Python function call directly inside Relax IR.
When the module
+# is compiled and run in the VM, everything else is optimized native code, but
the VM calls
+# back into Python for the specified ops.
+#
+# ``BasePyModule`` supports cross-level calls in both directions: Relax
functions can invoke
+# Python functions, and Python functions can invoke TIR/Relax functions. Data
flows between
+# them via DLPack with minimal overhead.
+#
+# Use case: your model has a custom op (e.g., a special normalization or a
sampling step)
+# that is complex to implement in TIR. Compile everything else, and let that
one op stay
+# in Python.
+
+if RUN_EXAMPLE:
+
+ @I.ir_module
+ class HybridVMModule(BasePyModule):
+ @I.pyfunc
+ def silu(self, x):
+ """SiLU/Swish activation — using Python as fallback."""
+ return torch.sigmoid(x) * x
+
+ @I.pyfunc
+ def layer_norm(self, x):
+ """LayerNorm — another Python fallback."""
+ return F.layer_norm(x, x.shape[-1:])
+
+ @R.function
+ def main(
+ x: R.Tensor((4, 8), "float32"),
+ ) -> R.Tensor((4, 8), "float32"):
+ # The VM calls back into Python for these two ops
+ h = R.call_py_func(
+ "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")
+ )
+ out = R.call_py_func(
+ "silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")
+ )
+ return out
+
+ mod = HybridVMModule(device=tvm.cpu(0))
+ x = torch.randn(4, 8)
+
+ # call_py_func is also callable from Python directly
+ result = mod.call_py_func("layer_norm", [x])
+ result = mod.call_py_func("silu", [result])
+
+ ln = F.layer_norm(x, x.shape[-1:])
+ expected = torch.sigmoid(ln) * ln
+ print("call_py_func result:", result)
+ assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5)
+
+
+######################################################################
+# Step 6: Cross-Level Calls and Symbolic Shapes
+# ------------------------------------------------
+# ``BasePyModule`` is designed for **cross-level interoperability**: Python
functions can call
+# TIR and Relax functions, and Relax functions can call Python functions. We
have already seen:
+#
+# - Python → TIR via ``call_tir`` (Steps 1–3)
+# - Python → packed function via ``call_dps_packed`` (Step 3)
+# - Relax → Python via ``R.call_py_func`` (Step 5)
+#
+# The missing piece: **Python calling a compiled Relax function directly**.
When a module
+# contains ``@R.function``, it is JIT-compiled into a Relax VM. You can call
it from Python
+# just like any other method — the module auto-converts PyTorch tensors to TVM
and back.
+#
+# This step also shows **symbolic shapes**: TIR and Relax functions can
declare dynamic
+# dimensions (e.g., ``"n"``). ``BasePyModule`` infers concrete shapes from the
actual input
+# tensors at call time, so the same module handles different sizes without
recompilation.
+
+if RUN_EXAMPLE:
+
+ @I.ir_module
+ class DynamicModule(BasePyModule):
+ @T.prim_func
+ def scale_tir(var_x: T.handle, var_out: T.handle):
+ n = T.int64()
+ x = T.match_buffer(var_x, (n,), "float32")
+ out = T.match_buffer(var_out, (n,), "float32")
+ for i in T.serial(n):
+ out[i] = x[i] * T.float32(2.0)
+
+ @R.function
+ def add_relax(
+ x: R.Tensor(("n",), "float32"),
+ y: R.Tensor(("n",), "float32"),
+ ) -> R.Tensor(("n",), "float32"):
+ return R.add(x, y)
+
+ mod = DynamicModule(device=tvm.cpu(0), target="llvm")
+
+ # Inspect what the module contains
+ print("Functions:", mod.list_functions())
+
+ # Python → Relax: call the compiled Relax function directly with PyTorch
tensors
+ a5 = torch.randn(5)
+ b5 = torch.randn(5)
+ out5 = mod.add_relax(a5, b5)
+ print("add_relax(len=5):", out5)
+
+ # Same module, different size — symbolic shapes handle this automatically
+ a10 = torch.randn(10)
+ b10 = torch.randn(10)
+ out10 = mod.add_relax(a10, b10)
+ print("add_relax(len=10):", out10)
+
+ # Python → TIR with symbolic output shape
+ n = T.int64()
+ x7 = torch.randn(7)
+ scaled = mod.call_tir(
+ "scale_tir", [x7], relax.TensorStructInfo((n,), "float32")
+ )
+ print("scale_tir(len=7):", scaled)
+ assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)
+
+
+######################################################################
+# Summary
+# -------
+# Cross-level call summary:
+#
+# - **Python → TIR**: ``call_tir()`` (Steps 1, 2, 3, 6)
+# - **Python → packed function**: ``call_dps_packed()`` (Step 3)
+# - **Python → Relax**: call ``@R.function`` as a method (Step 6)
+# - **Relax → Python**: ``R.call_py_func()`` in compiled VM (Step 5)
+#
+# The workflow in practice:
+#
+# 1. Import a model → some ops unsupported → use ``@I.pyfunc`` as Python
fallbacks
+# 2. Get it running end-to-end with ``BasePyModule``
+# 3. Debug by inserting ``print`` in pyfuncs — inspect intermediate tensors
instantly
+# 4. Use ``RelaxToPyFuncConverter`` to verify correctness after each
optimization pass
+# 5. Gradually replace Python fallbacks with TIR/Relax implementations
+# 6. Use ``R.call_py_func`` for ops that must stay in Python even after
compilation
diff --git a/python/tvm/relax/base_py_module.py
b/python/tvm/relax/base_py_module.py
index 7840e5b3b4..67b6761633 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -146,14 +146,7 @@ class BasePyModule:
except Exception as error:
print(f"Warning: Failed to compile one or more TIR functions:
{error}")
- relax_mod = tvm.IRModule(
- {
- gv: func
- for gv, func in self.ir_mod.functions_items()
- if isinstance(func, relax.Function)
- }
- )
- if relax_mod:
+ if self.relax_func_names:
try:
exec_mod = tvm.compile(self.ir_mod, target=self.target)
self.relax_vm = relax.VirtualMachine(exec_mod, self.device)