MasterJH5574 commented on code in PR #18229:
URL: https://github.com/apache/tvm/pull/18229#discussion_r2299000434


##########
python/tvm/relax/base_py_module.py:
##########
@@ -0,0 +1,367 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""BasePyModule: Base class for IRModules with Python function support."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import tvm
+from tvm import relax, tir
+from tvm.ir import IRModule
+from tvm.runtime import Device, NDArray, PackedFunc
+from tvm.target import Target
+
+try:
+    from torch.utils.dlpack import to_dlpack as to_dlpack_legacy
+except ImportError:
+    to_dlpack_legacy = None
+
+
+class BasePyModule:
+    """Base class that allows Python functions in IRModule with DLPack 
conversion.
+
+    This class provides the infrastructure for:
+    1. JIT compilation of TIR and Relax functions.
+    2. DLPack-based conversion between PyTorch tensors and TVM NDArrays.
+    3. Wrapping Relax functions for easy Python calling.
+    4. Cross-function calls between Python, TIR, and Relax functions.
+
+    Only IRModules that inherit from this class are allowed to contain Python functions.
+    """
+
+    def __init__(
+        self,
+        ir_mod: IRModule,
+        device: Device,
+        target: Optional[Target] = None,
+    ):
+        """Initialize BasePyModule with JIT compilation and DLPack 
conversion."""
+        self.device = device
+        self.ir_mod = ir_mod
+
+        # Delegate IRModule operations
+        self.functions = ir_mod.functions
+        self.attrs = ir_mod.attrs
+        self.global_infos = ir_mod.global_infos
+        self.__getitem__ = ir_mod.__getitem__
+        self.__setitem__ = ir_mod.__setitem__
+        self.functions_items = ir_mod.functions_items
+        self.with_attr = ir_mod.with_attr
+        self.get_attr = ir_mod.get_attr
+        self.update_global_info = ir_mod.update_global_info
+
+        def _getattr_python_function(name: str) -> Any:
+            """Support direct attribute access to funcs and IRModule 
methods."""
+            if name in self.pyfuncs:
+                return self.pyfuncs[name]
+            if name in self.compiled_tir_funcs:
+                return self.compiled_tir_funcs[name]
+            if self.relax_vm and name in self.relax_func_names:
+                try:
+                    return self.relax_vm[name]
+                except AttributeError:  # More specific exception
+                    return None
+            if hasattr(self.ir_mod, name):
+                return getattr(self.ir_mod, name)
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+        self.__getattr__ = _getattr_python_function
+
+        self.compiled_tir_funcs: Dict[str, PackedFunc] = {}
+        self.extern_funcs: Dict[str, PackedFunc] = {}
+        self.tir_func_names: List[str] = []
+        self.relax_func_names: List[str] = []
+        self.relax_vm: Optional[relax.VirtualMachine] = None
+        self.pyfuncs: Dict[str, Any] = {}
+
+        if target is None:
+            target = Target.from_device(device)
+        elif isinstance(target, str):
+            target = Target(target)
+        self.target = target
+
+        self._collect_function_names()
+        self._compile_functions()
+        self._wrap_tir_functions()
+        self._wrap_relax_functions()
+
+    def _collect_function_names(self):
+        """Collect names of TIR and Relax functions from IRModule."""
+        for global_var, func in self.ir_mod.functions_items():
+            if isinstance(func, tir.PrimFunc):
+                self.tir_func_names.append(global_var.name_hint)
+            elif isinstance(func, relax.Function):
+                self.relax_func_names.append(global_var.name_hint)
+
+    def _compile_functions(self):
+        """Compile TIR and Relax functions using JIT compilation."""
+        # Compile TIR functions first
+        tir_mod = tvm.IRModule(
+            {
+                gv: func
+                for gv, func in self.ir_mod.functions_items()
+                if isinstance(func, tir.PrimFunc)
+            }
+        )
+        if tir_mod:
+            try:
+                tir_exec_mod = tvm.compile(tir_mod, target=self.target)
+                for func_name in self.tir_func_names:
+                    self.compiled_tir_funcs[func_name] = tir_exec_mod[func_name]
+            except (tvm.TVMError, RuntimeError) as error:
+                print(f"Warning: Failed to compile one or more TIR functions: 
{error}")
+
+        relax_mod = tvm.IRModule(
+            {
+                gv: func
+                for gv, func in self.ir_mod.functions_items()
+                if isinstance(func, relax.Function)
+            }
+        )
+        if relax_mod:
+            try:
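+                # Note: the full ir_mod is compiled here; relax_mod above only
+                # gates whether any Relax functions exist to build a VM for.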
+                exec_mod = tvm.compile(self.ir_mod, target=self.target)
+                self.relax_vm = relax.VirtualMachine(exec_mod, self.device)
+            except (tvm.TVMError, RuntimeError) as error:
+                print(f"Warning: Failed to compile Relax VM: {error}")
+                self.relax_vm = None
+
+    def _wrap_tir_functions(self):
+        """Wrap TIR functions to make them accessible as instance 
attributes."""
+        for func_name, func in self.compiled_tir_funcs.items():
+            setattr(self, func_name, func)
+
+    def _wrap_relax_functions(self):
+        """Wrap Relax functions to be callable from Python with auto 
conversion."""
+        if self.relax_vm is None:
+            return
+
+        for func_name in self.relax_func_names:
+
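+            # A factory function is used so each wrapper captures its own
+            # `name`, avoiding Python's late-binding closure pitfall in loops.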
+            def _create_relax_wrapper(name):
+                def wrapper(*args, **kwargs):
+                    """Wrapper for Relax function with automatic tensor 
conversion."""
+                    converted_args = self._convert_pytorch_to_tvm(list(args))
+                    converted_kwargs = {
+                        k: self._convert_pytorch_to_tvm(v) for k, v in kwargs.items()
+                    }
+                    result = self.relax_vm[name](*converted_args, **converted_kwargs)
+                    return self._convert_tvm_to_pytorch(result)
+
+                wrapper.__name__ = name
+                wrapper.__doc__ = f"Wrapped Relax function: {name}"
+                return wrapper
+
+            setattr(self, func_name, _create_relax_wrapper(func_name))
+
+    def call_tir(self, tir_func, args, out_sinfo):
+        """Call a TIR function with PyTorch tensors."""
+        # Try to get function name from different sources
+        if isinstance(tir_func, str):
+            func_name = tir_func
+        elif hasattr(tir_func, "name"):
+            func_name = tir_func.name
+        elif hasattr(tir_func, "__name__"):
+            func_name = tir_func.__name__
+        else:
+            # Try to find by function object reference
+            for name, func in self.compiled_tir_funcs.items():
+                if func == tir_func:
+                    func_name = name
+                    break
+            else:
+                func_name = None
+
+        if not func_name or func_name not in self.compiled_tir_funcs:
+            available_funcs = list(self.compiled_tir_funcs.keys())
+            raise ValueError(
+                f"Could not resolve or find compiled TIR function: {tir_func}. 
"
+                f"Available functions: {available_funcs}"
+            )
+        func = self.compiled_tir_funcs[func_name]
+
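+        # Compiled PrimFuncs follow destination-passing style: output tensors
+        # are allocated up front and passed after the input arguments.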
+        out = self._create_output_tensors(out_sinfo)
+        tvm_args = self._convert_pytorch_to_tvm(args)
+        tvm_out = self._convert_pytorch_to_tvm(out)
+
+        func(*tvm_args, *tvm_out)
+
+        result = self._convert_tvm_to_pytorch(tvm_out)
+        return result[0] if len(result) == 1 else result
+
+    def call_dps_packed(self, func_name: str, args, out_sinfo):
+        """Call a packed function with PyTorch tensors, converting TVM 
NDArrays via DLPack."""
+        if hasattr(self, func_name) and callable(getattr(self, func_name)):
+            return getattr(self, func_name)(*args)
+
+        if func_name not in self.extern_funcs:
+            try:
+                self.extern_funcs[func_name] = tvm.get_global_func(func_name)
+            except ValueError as error:
+                raise ValueError(
+                    f"Function '{func_name}' not found as a global function. "
+                    f"Please implement it as a method or register it."
+                ) from error
+        func = self.extern_funcs[func_name]
+
+        out = self._create_output_tensors(out_sinfo)
+        tvm_args = self._convert_pytorch_to_tvm(args)
+        tvm_out = self._convert_pytorch_to_tvm(out)
+        func(*tvm_args, *tvm_out)
+        result = self._convert_tvm_to_pytorch(tvm_out)
+        return result[0] if len(result) == 1 else result
+
+    def call_py_func(self, func_name: str, args):
+        """Call a Python function stored in the IRModule's pyfuncs."""
+        if func_name not in self.ir_mod.pyfuncs:
+            raise ValueError(f"Python function '{func_name}' not found in 
IRModule pyfuncs")
+        py_func = self.ir_mod.pyfuncs[func_name]
+        converted_args = self._convert_tvm_to_pytorch(args)
+        return py_func(*converted_args)
+
+    def _create_output_tensors(self, out_sinfo):
+        """Create output PyTorch tensors based on shape and type 
information."""
+        sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo]
+        out_tensors = []
+        for sinfo in sinfo_list:
+            if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"):
+                shape = [int(val) for val in sinfo.shape]
+                torch_dtype = self._convert_tvm_dtype_to_torch(sinfo.dtype)
+                out_tensors.append(torch.empty(shape, dtype=torch_dtype))
+            else:
+                out_tensors.append(torch.empty((1,), dtype=torch.float32))
+        return out_tensors
+
+    def _convert_tvm_dtype_to_torch(self, tvm_dtype: str) -> torch.dtype:
+        """Convert TVM dtype string to PyTorch dtype."""
+        dtype_mapping = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int32": torch.int32,
+            "int64": torch.int64,
+            "bool": torch.bool,
+        }
+        return dtype_mapping.get(str(tvm_dtype), torch.float32)
+
+    def _convert_pytorch_to_tvm(
+        self, tensors: Union[Any, List[Any], Tuple[Any, ...]]
+    ) -> Union[NDArray, List[NDArray]]:
+        """Convert PyTorch tensors to TVM NDArrays using DLPack."""
+        if isinstance(tensors, (list, tuple)):
+            return [self._convert_single_pytorch_to_tvm(t) for t in tensors]
+        return self._convert_single_pytorch_to_tvm(tensors)
+
+    def _convert_single_pytorch_to_tvm(self, tensor: Any) -> NDArray:
+        """Convert a single PyTorch tensor to TVM NDArray with robust 
fallbacks."""
+        if isinstance(tensor, NDArray):
+            return tensor
+        if isinstance(tensor, torch.Tensor):
+            # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
+            try:
+                dlpack = torch.to_dlpack(tensor)
+                return tvm.nd.from_dlpack(dlpack)
+            except (AttributeError, ValueError):
+                pass  # Fall through to the next method
+            # 2. Try legacy `torch.utils.dlpack.to_dlpack`
+            if to_dlpack_legacy:
+                try:
+                    dlpack = to_dlpack_legacy(tensor)
+                    return tvm.nd.from_dlpack(dlpack)
+                except (AttributeError, ValueError) as error_legacy:
+                    print(
+                        f"Warning: Legacy DLPack conversion failed 
({error_legacy}), "
+                        f"using numpy fallback."
+                    )
+            # 3. If all DLPack methods fail, use numpy fallback
+            numpy_array = tensor.detach().cpu().numpy()
+            return tvm.nd.array(numpy_array, device=self.device)
+
+        # For other types (like scalars, lists), convert to numpy first
+        try:
+            numpy_array = np.array(tensor, dtype=np.float32)
+            return tvm.nd.array(numpy_array, device=self.device)
+        except (TypeError, ValueError) as error:
+            raise TypeError(
+                f"Unsupported type for conversion to TVM NDArray: 
{type(tensor)}"
+            ) from error
+
+    def _convert_tvm_to_pytorch(
+        self, tvm_arrays: Union[Any, List[Any]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Convert TVM NDArrays to PyTorch tensors using DLPack."""
+        if isinstance(tvm_arrays, (list, tuple)):
+            return [self._convert_single_tvm_to_pytorch(arr) for arr in tvm_arrays]
+        return self._convert_single_tvm_to_pytorch(tvm_arrays)
+
+    def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> torch.Tensor:
+        """Convert a single TVM NDArray to PyTorch tensor using DLPack."""
+        if isinstance(tvm_array, torch.Tensor):
+            return tvm_array
+        if not isinstance(tvm_array, NDArray):
+            return torch.tensor(tvm_array)
+        try:
+            dlpack = tvm_array.to_dlpack()
+            return torch.from_dlpack(dlpack)
+        except (tvm.TVMError, RuntimeError) as error:
+            print(f"Warning: DLPack conversion from TVM failed ({error}), 
using numpy fallback")
+            numpy_array = tvm_array.asnumpy()

Review Comment:
   ```suggestion
               numpy_array = tvm_array.numpy()
   ```
   `asnumpy` has been formally removed by the latest commits; let's use `.numpy()`. We likely also need to rebase this PR onto the latest main branch.
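   For reference, `.numpy()` is a drop-in rename of the removed `.asnumpy()`. A minimal standalone sketch of the round trip this numpy fallback performs (illustration only, not code from this PR):
   ```python
   import numpy as np
   import torch
   import tvm

   # TVM NDArray -> numpy, using the new spelling of the removed `.asnumpy()`
   arr = tvm.nd.array(np.arange(4, dtype="float32"))
   np_view = arr.numpy()

   # numpy -> torch covers the same fallback direction as the code above
   tensor = torch.from_numpy(np_view)
   assert tensor.dtype == torch.float32
   ```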
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
