This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e229bda76f [Docs] Add tutorial for mixing Python/PyTorch with TVM 
using BasePyModule (#18947)
e229bda76f is described below

commit e229bda76faff035a19fcdc515be51059ed4957b
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 29 15:26:49 2026 -0400

    [Docs] Add tutorial for mixing Python/PyTorch with TVM using BasePyModule 
(#18947)
    
+    This PR adds a new tutorial `mix_python_and_tvm_with_pymodule.py`
    demonstrating how to use `BasePyModule` to mix Python/PyTorch functions
    with TIR and Relax in a single IRModule.
    ## Tutorial Contents (7 steps)
    - **Step 1**: `I.pyfunc` + `call_tir` basics, DLPack zero-copy
    conversion, `show()`
    - **Step 2**: Debugging with `print` in pyfuncs — inspect intermediate
    tensors without compiling
    - **Step 3**: Realistic pipeline combining `call_tir`,
    `call_dps_packed`, and Python/PyTorch in one forward pass
    - **Step 4**: Dynamic function registration via `add_python_function`
    - **Step 5**: `RelaxToPyFuncConverter` — convert Relax IR to PyTorch at
    different compilation stages (before and after passes) to verify
    numerical correctness
    - **Step 6**: `R.call_py_func` — cross-level calls between compiled
    Relax VM and Python functions
    - **Step 7**: Symbolic shapes for dynamic batch sizes
+    This PR also fixes a bug in `BasePyModule._compile_functions` where
    modules without Relax functions would incorrectly attempt Relax VM
    compilation, producing spurious warnings like `Failed to compile Relax
    VM: 'NoneType' object has no attribute 'kind'`.
---
 .../tutorials/mix_python_and_tvm_with_pymodule.py  | 468 +++++++++++++++++++++
 python/tvm/relax/base_py_module.py                 |   9 +-
 2 files changed, 469 insertions(+), 8 deletions(-)

diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py 
b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
new file mode 100644
index 0000000000..91d1cb9c26
--- /dev/null
+++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
@@ -0,0 +1,468 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402
+
+"""
+.. _mix_python_and_tvm:
+
+Mix Python/PyTorch with TVM Using BasePyModule
+===============================================
+In a typical TVM workflow, you write an ``IRModule``, compile it, and load the 
compiled artifact
+into a ``VirtualMachine`` to run. This means **you cannot test or debug 
anything until the entire
+module compiles successfully**. If a single op is unsupported, the whole 
pipeline is blocked.
+
+``BasePyModule`` solves this by letting Python functions, TIR kernels, and 
Relax functions coexist
+in one module. TIR and Relax functions are JIT-compiled on instantiation, 
Python functions run
+as-is, and tensors move between TVM and PyTorch via zero-copy DLPack. This 
enables:
+
+- **Incremental development**: get a model running with Python fallbacks 
first, then replace them
+  with TVM ops one by one.
+- **Easy debugging**: insert ``print`` in Python functions to inspect 
intermediate tensors — no
+  need to compile the whole module first.
+- **Verification at any compilation stage**: convert Relax IR back to PyTorch 
to check numerical
+  correctness before and after optimization passes.
+- **Hybrid execution**: let the compiled VM call back into Python for ops that 
are hard to
+  express in TIR or Relax.
+
+This tutorial walks through the full workflow step by step.
+
+.. contents:: Table of Contents
+    :local:
+    :depth: 1
+"""
+
+######################################################################
+# Preparation
+# -----------
+
+import os
+
+try:
+    import torch
+    import torch.nn.functional as F
+except ImportError:
+    torch = None
+
+import tvm
+from tvm import relax
+from tvm.relax.base_py_module import BasePyModule
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+IS_IN_CI = os.getenv("CI", "").lower() == "true"
+HAS_TORCH = torch is not None
+RUN_EXAMPLE = HAS_TORCH and not IS_IN_CI
+
+
+######################################################################
+# Step 1: Your First Hybrid Module
+# ----------------------------------
+# The core idea: decorate a class with ``@I.ir_module``, inherit from 
``BasePyModule``, and use
+# three decorators for three kinds of functions:
+#
+# - ``@T.prim_func`` — low-level TIR kernel (JIT-compiled on instantiation)
+# - ``@R.function`` — high-level Relax graph (JIT-compiled on instantiation)
+# - ``@I.pyfunc`` — plain Python (runs as-is, can use any Python library)
+#
+# ``call_tir`` bridges Python and TIR: it converts PyTorch tensors to TVM 
NDArrays via DLPack
+# (zero-copy), allocates the output buffer, calls the compiled kernel, and 
converts back.
+
+if RUN_EXAMPLE:
+
+    @I.ir_module
+    class MyFirstModule(BasePyModule):
+        @T.prim_func
+        def add_tir(
+            A: T.Buffer((4,), "float32"),
+            B: T.Buffer((4,), "float32"),
+            C: T.Buffer((4,), "float32"),
+        ):
+            for i in range(4):
+                C[i] = A[i] + B[i]
+
+        @I.pyfunc
+        def forward(self, x, y):
+            """Takes PyTorch tensors, calls TIR, returns PyTorch tensors."""
+            x_tvm = self._convert_pytorch_to_tvm(x)
+            y_tvm = self._convert_pytorch_to_tvm(y)
+            result = self.call_tir(
+                self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((4,), 
"float32")
+            )
+            return self._convert_tvm_to_pytorch(result)
+
+    # TIR functions are JIT-compiled at instantiation
+    mod = MyFirstModule(device=tvm.cpu(0))
+
+    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    y = torch.tensor([10.0, 20.0, 30.0, 40.0])
+    result = mod.forward(x, y)
+
+    print("forward(x, y) =", result)
+    assert torch.allclose(result, x + y)
+
+    # show() prints TVMScript including Python functions (shown as ExternFunc)
+    mod.show()
+
+    # list_functions() shows what is available in the module
+    print("Available functions:", mod.list_functions())
+
+
+######################################################################
+# Step 2: Debugging — The Main Selling Point
+# ---------------------------------------------
+# Traditional ML compilers treat computation graphs as monolithic blobs. You 
cannot inspect
+# intermediate tensor values without compiling the entire module. With 
``@I.pyfunc``, debugging
+# is as simple as adding a ``print`` statement. You can also make quick edits 
and re-run
+# immediately — no recompilation needed.
+
+if RUN_EXAMPLE:
+
+    @I.ir_module
+    class DebugModule(BasePyModule):
+        @T.prim_func
+        def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+            n = T.int32()
+            A = T.match_buffer(var_A, (n, 4), "float32")
+            B = T.match_buffer(var_B, (4, 3), "float32")
+            C = T.match_buffer(var_C, (n, 3), "float32")
+            for i, j, k in T.grid(n, 3, 4):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @I.pyfunc
+        def forward(self, x, weights):
+            # Inspect input
+            print(f"  [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")
+
+            # Run TIR matmul
+            x_tvm = self._convert_pytorch_to_tvm(x)
+            w_tvm = self._convert_pytorch_to_tvm(weights)
+            out = self.call_tir(
+                self.matmul_tir,
+                [x_tvm, w_tvm],
+                out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
+            )
+            logits = self._convert_tvm_to_pytorch(out)
+
+            # Inspect intermediate value — impossible with a compiled-only 
workflow
+            print(f"  [DEBUG] logits shape: {logits.shape}, "
+                  f"min: {logits.min():.4f}, max: {logits.max():.4f}")
+
+            result = F.softmax(logits, dim=-1)
+
+            # Verify output
+            print(f"  [DEBUG] probs sum: {result.sum(dim=-1)}")
+            return result
+
+    mod = DebugModule(device=tvm.cpu(0))
+
+    print("Running with debug prints:")
+    probs = mod.forward(torch.randn(2, 4), torch.randn(4, 3))
+    assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-5)
+
+######################################################################
+# This is the key benefit: "debugging is as simple as inserting a print 
statement.
+# Users can also make quick, manual edits to Python functions and immediately 
observe the
+# results." No compilation cycle, no VM loading — just Python.
+
+
+######################################################################
+# Step 3: A Realistic Pipeline — Python, TIR, and Packed Functions
+# -------------------------------------------------------------------
+# Real models combine many kinds of operations. This step builds a mini 
inference pipeline using
+# three different calling conventions:
+#
+# - ``call_tir``: call a compiled TIR kernel
+# - ``call_dps_packed``: call a TVM packed function (e.g., a third-party 
library binding)
+# - Direct Python: call any PyTorch function
+#
+# ``call_dps_packed`` is useful for calling functions registered via 
``tvm.register_global_func``
+# — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions.
+
+if RUN_EXAMPLE:
+
+    # Register a packed function (simulating an external library binding)
+    @tvm.register_global_func("my_bias_add", override=True)
+    def my_bias_add(x, bias, out):
+        """Packed function: adds bias to each row of x."""
+        import numpy as np
+
+        x_np = x.numpy()
+        b_np = bias.numpy()
+        out_np = x_np + b_np
+        out[:] = out_np
+
+    @I.ir_module
+    class PipelineModule(BasePyModule):
+        @T.prim_func
+        def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+            A = T.match_buffer(var_A, (2, 4), "float32")
+            B = T.match_buffer(var_B, (4, 3), "float32")
+            C = T.match_buffer(var_C, (2, 3), "float32")
+            for i, j, k in T.grid(2, 3, 4):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @I.pyfunc
+        def forward(self, x, weights, bias):
+            # 1. TIR matmul
+            x_tvm = self._convert_pytorch_to_tvm(x)
+            w_tvm = self._convert_pytorch_to_tvm(weights)
+            h = self.call_tir(
+                self.matmul_tir, [x_tvm, w_tvm],
+                out_sinfo=R.Tensor((2, 3), "float32"),
+            )
+            h_pt = self._convert_tvm_to_pytorch(h)
+
+            # 2. Packed function for bias add (simulating an external library)
+            h_biased = self.call_dps_packed(
+                "my_bias_add", [h_pt, bias],
+                out_sinfo=R.Tensor((2, 3), "float32"),
+            )
+
+            # 3. Python/PyTorch activation
+            return F.relu(h_biased)
+
+    mod = PipelineModule(device=tvm.cpu(0))
+
+    x = torch.randn(2, 4)
+    w = torch.randn(4, 3)
+    b = torch.randn(3)
+    result = mod.forward(x, w, b)
+
+    expected = F.relu(x @ w + b)
+    print("Pipeline result:", result)
+    print("Expected:       ", expected)
+    assert torch.allclose(result, expected, atol=1e-4)
+
+
+######################################################################
+# Step 4: Relax-to-Python Converter — Verify at Any Compilation Stage
+# ----------------------------------------------------------------------
+# Both Relax functions and Python functions describe computational graphs. The
+# ``RelaxToPyFuncConverter`` converts Relax IR into equivalent PyTorch code by 
mapping
+# Relax operators to their PyTorch counterparts (e.g., ``R.nn.relu`` → 
``F.relu``).
+#
+# A key feature: **this conversion can happen at any stage of compilation**.
+# You can convert early (right after import) or late (after optimization 
passes have
+# transformed the IR), and compare the output against a PyTorch reference to 
catch bugs.
+
+if RUN_EXAMPLE:
+    from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
+
+    # A simple Relax module: matmul + bias + relu (a dense layer)
+    @I.ir_module
+    class DenseLayer:
+        @T.prim_func
+        def bias_add_tir(var_x: T.handle, var_b: T.handle, var_out: T.handle):
+            x = T.match_buffer(var_x, (2, 4), "float32")
+            b = T.match_buffer(var_b, (4,), "float32")
+            out = T.match_buffer(var_out, (2, 4), "float32")
+            for i, j in T.grid(2, 4):
+                out[i, j] = x[i, j] + b[j]
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 4), "float32"),
+            w: R.Tensor((4, 4), "float32"),
+            b: R.Tensor((4,), "float32"),
+        ) -> R.Tensor((2, 4), "float32"):
+            h = R.matmul(x, w)
+            cls = DenseLayer
+            h_bias = R.call_tir(
+                cls.bias_add_tir, (h, b),
+                out_sinfo=R.Tensor((2, 4), "float32"),
+            )
+            return R.nn.relu(h_bias)
+
+    # --- Stage 1: Convert BEFORE optimization ---
+    converter = RelaxToPyFuncConverter(DenseLayer)
+    converted_early = converter.convert(["main"])
+
+    x = torch.randn(2, 4)
+    w = torch.randn(4, 4)
+    b = torch.randn(4)
+
+    py_result_early = converted_early.pyfuncs["main"](x, w, b)
+    expected = F.relu(x @ w + b)
+
+    print("Before optimization:")
+    print("  Converted result:", py_result_early)
+    print("  PyTorch expected:", expected)
+    assert torch.allclose(py_result_early, expected, atol=1e-5)
+
+    # --- Stage 2: Apply a pass, then convert AFTER optimization ---
+    # Run CanonicalizeBindings to clean up the IR, then convert again
+    # to verify the pass did not break numerical correctness.
+    optimized_mod = relax.transform.CanonicalizeBindings()(DenseLayer)
+
+    converter_late = RelaxToPyFuncConverter(optimized_mod)
+    converted_late = converter_late.convert(["main"])
+
+    py_result_late = converted_late.pyfuncs["main"](x, w, b)
+
+    print("\nAfter CanonicalizeBindings pass:")
+    print("  Converted result:", py_result_late)
+    print("  Still matches:   ",
+          torch.allclose(py_result_late, expected, atol=1e-5))
+    assert torch.allclose(py_result_late, expected, atol=1e-5)
+
+
+######################################################################
+# Step 5: R.call_py_func — Python Callbacks in Compiled IR
+# -----------------------------------------------------------
+# ``R.call_py_func`` embeds a Python function call directly inside Relax IR. 
When the module
+# is compiled and run in the VM, everything else is optimized native code, but 
the VM calls
+# back into Python for the specified ops.
+#
+# ``BasePyModule`` supports cross-level calls in both directions: Relax 
functions can invoke
+# Python functions, and Python functions can invoke TIR/Relax functions. Data 
flows between
+# them via DLPack with minimal overhead.
+#
+# Use case: your model has a custom op (e.g., a special normalization or a 
sampling step)
+# that is complex to implement in TIR. Compile everything else, and let that 
one op stay
+# in Python.
+
+if RUN_EXAMPLE:
+
+    @I.ir_module
+    class HybridVMModule(BasePyModule):
+        @I.pyfunc
+        def silu(self, x):
+            """SiLU/Swish activation — using Python as fallback."""
+            return torch.sigmoid(x) * x
+
+        @I.pyfunc
+        def layer_norm(self, x):
+            """LayerNorm — another Python fallback."""
+            return F.layer_norm(x, x.shape[-1:])
+
+        @R.function
+        def main(
+            x: R.Tensor((4, 8), "float32"),
+        ) -> R.Tensor((4, 8), "float32"):
+            # The VM calls back into Python for these two ops
+            h = R.call_py_func(
+                "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")
+            )
+            out = R.call_py_func(
+                "silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")
+            )
+            return out
+
+    mod = HybridVMModule(device=tvm.cpu(0))
+    x = torch.randn(4, 8)
+
+    # call_py_func is also callable from Python directly
+    result = mod.call_py_func("layer_norm", [x])
+    result = mod.call_py_func("silu", [result])
+
+    ln = F.layer_norm(x, x.shape[-1:])
+    expected = torch.sigmoid(ln) * ln
+    print("call_py_func result:", result)
+    assert torch.allclose(torch.tensor(result.numpy()), expected, atol=1e-5)
+
+
+######################################################################
+# Step 6: Cross-Level Calls and Symbolic Shapes
+# ------------------------------------------------
+# ``BasePyModule`` is designed for **cross-level interoperability**: Python 
functions can call
+# TIR and Relax functions, and Relax functions can call Python functions. We 
have already seen:
+#
+# - Python → TIR via ``call_tir`` (Steps 1–3)
+# - Python → packed function via ``call_dps_packed`` (Step 3)
+# - Relax → Python via ``R.call_py_func`` (Step 5)
+#
+# The missing piece: **Python calling a compiled Relax function directly**. 
When a module
+# contains ``@R.function``, it is JIT-compiled into a Relax VM. You can call 
it from Python
+# just like any other method — the module auto-converts PyTorch tensors to TVM 
and back.
+#
+# This step also shows **symbolic shapes**: TIR and Relax functions can 
declare dynamic
+# dimensions (e.g., ``"n"``). ``BasePyModule`` infers concrete shapes from the 
actual input
+# tensors at call time, so the same module handles different sizes without 
recompilation.
+
+if RUN_EXAMPLE:
+
+    @I.ir_module
+    class DynamicModule(BasePyModule):
+        @T.prim_func
+        def scale_tir(var_x: T.handle, var_out: T.handle):
+            n = T.int64()
+            x = T.match_buffer(var_x, (n,), "float32")
+            out = T.match_buffer(var_out, (n,), "float32")
+            for i in T.serial(n):
+                out[i] = x[i] * T.float32(2.0)
+
+        @R.function
+        def add_relax(
+            x: R.Tensor(("n",), "float32"),
+            y: R.Tensor(("n",), "float32"),
+        ) -> R.Tensor(("n",), "float32"):
+            return R.add(x, y)
+
+    mod = DynamicModule(device=tvm.cpu(0), target="llvm")
+
+    # Inspect what the module contains
+    print("Functions:", mod.list_functions())
+
+    # Python → Relax: call the compiled Relax function directly with PyTorch 
tensors
+    a5 = torch.randn(5)
+    b5 = torch.randn(5)
+    out5 = mod.add_relax(a5, b5)
+    print("add_relax(len=5):", out5)
+
+    # Same module, different size — symbolic shapes handle this automatically
+    a10 = torch.randn(10)
+    b10 = torch.randn(10)
+    out10 = mod.add_relax(a10, b10)
+    print("add_relax(len=10):", out10)
+
+    # Python → TIR with symbolic output shape
+    n = T.int64()
+    x7 = torch.randn(7)
+    scaled = mod.call_tir(
+        "scale_tir", [x7], relax.TensorStructInfo((n,), "float32")
+    )
+    print("scale_tir(len=7):", scaled)
+    assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)
+
+
+######################################################################
+# Summary
+# -------
+# Cross-level call summary:
+#
+# - **Python → TIR**: ``call_tir()`` (Steps 1, 2, 3, 6)
+# - **Python → packed function**: ``call_dps_packed()`` (Step 3)
+# - **Python → Relax**: call ``@R.function`` as a method (Step 6)
+# - **Relax → Python**: ``R.call_py_func()`` in compiled VM (Step 5)
+#
+# The workflow in practice:
+#
+# 1. Import a model → some ops unsupported → use ``@I.pyfunc`` as Python 
fallbacks
+# 2. Get it running end-to-end with ``BasePyModule``
+# 3. Debug by inserting ``print`` in pyfuncs — inspect intermediate tensors 
instantly
+# 4. Use ``RelaxToPyFuncConverter`` to verify correctness after each 
optimization pass
+# 5. Gradually replace Python fallbacks with TIR/Relax implementations
+# 6. Use ``R.call_py_func`` for ops that must stay in Python even after 
compilation
diff --git a/python/tvm/relax/base_py_module.py 
b/python/tvm/relax/base_py_module.py
index 7840e5b3b4..67b6761633 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -146,14 +146,7 @@ class BasePyModule:
             except Exception as error:
                 print(f"Warning: Failed to compile one or more TIR functions: 
{error}")
 
-        relax_mod = tvm.IRModule(
-            {
-                gv: func
-                for gv, func in self.ir_mod.functions_items()
-                if isinstance(func, relax.Function)
-            }
-        )
-        if relax_mod:
+        if self.relax_func_names:
             try:
                 exec_mod = tvm.compile(self.ir_mod, target=self.target)
                 self.relax_vm = relax.VirtualMachine(exec_mod, self.device)

Reply via email to