This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c852002334 [Relax] Add Relax to Python Function Converter (#18269)
c852002334 is described below

commit c8520023345876b3560dae3c4a477e5c4e8cbd0b
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Sep 8 13:47:01 2025 -0400

    [Relax] Add Relax to Python Function Converter (#18269)
    
    ### Overview
    This PR implements a Relax to Python Function Converter that transforms
    Relax functions into executable Python functions using PyTorch operations.
    This enables seamless conversion between TVM's Relax IR and Python/PyTorch
    environments, providing enhanced debugging capabilities and leveraging
    existing PyTorch operator libraries for testing and deployment.
    
    ### Key Features
    - **High-level operator mapping**: Maps 60+ Relax operators to
    corresponding PyTorch APIs
    - **Special operation handling**: Supports `call_tir`, `call_dps_packed`,
    and Relax function calls with DLPack integration
    - **Symbolic shape support**: Handles symbolic shapes and dynamic tensor
    operations
    
    ### Example
    ```python
    from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
    
    # Convert Relax functions to Python functions
    converter = RelaxToPyFuncConverter(ir_module)
    converted_ir_mod = converter.convert("my_function")
    
    # Execute converted function with PyTorch tensors
    result = converted_ir_mod.pyfuncs['my_function'](input_tensor)
    ```
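    
    The patch also adds a module-level helper, `convert_relax_to_pyfunc`, which
    accepts either a single function name or a list of names. A minimal sketch of
    that path (the tensor shape and function names below are illustrative
    assumptions, not taken from a real module):
    
    ```python
    import torch
    from tvm.relax.relax_to_pyfunc_converter import convert_relax_to_pyfunc
    
    # Convert several Relax functions in one call; the converted callables are
    # stored on the returned IRModule under `pyfuncs`, keyed by function name.
    converted_ir_mod = convert_relax_to_pyfunc(ir_module, ["func1", "func2"])
    
    # Each converted function runs directly on PyTorch tensors.
    x = torch.randn(5)  # illustrative input shape
    result = converted_ir_mod.pyfuncs["func1"](x)
    ```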
---
 python/tvm/relax/relax_to_pyfunc_converter.py      | 1104 ++++++++++++++++++++
 .../python/relax/test_relax_to_pyfunc_converter.py |  866 +++++++++++++++
 2 files changed, 1970 insertions(+)

diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py
new file mode 100644
index 0000000000..3de27d78c8
--- /dev/null
+++ b/python/tvm/relax/relax_to_pyfunc_converter.py
@@ -0,0 +1,1104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax to Python Function Converter.
+
+This module provides functionality to convert Relax functions to Python functions
+that can be executed directly in Python/PyTorch environment.
+"""
+
+from typing import Any, Dict, List, Union
+
+import torch
+import torch.nn.functional as F
+
+import tvm
+from tvm import relax
+from tvm.ir import IRModule, Op
+
+
+class RelaxToPyFuncConverter:
+    """Converter that works with IRModule to convert Relax functions to Python 
functions.
+
+    This converter transforms Relax functions into Python functions that can 
be executed
+    directly in Python/PyTorch environment. The conversion maps Relax 
operators to
+    corresponding PyTorch APIs and handles special cases like call_tir and 
call_dps_packed.
+    """
+
+    def __init__(self, ir_module: IRModule):
+        """Initialize the converter with an IRModule.
+
+        Args:
+            ir_module: The IRModule containing Relax functions to convert
+        """
+        self.ir_module = ir_module
+        self.operator_map = self._get_op_map()
+        # Cache for RelaxExpressionConverter instances to avoid recreating them
+        self._converter_cache = {}
+        # Cache for operator mappings to avoid repeated lookups
+        self._op_cache = {}
+
+    def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule:
+        """Convert specified Relax functions to Python functions.
+
+        Args:
+            relax_function_names: Name(s) of Relax functions to convert
+
+        Returns:
+            Updated IRModule with converted Python functions stored in pyfuncs
+
+        Example:
+            >>> converter = RelaxToPyFuncConverter(ir_mod)
+            >>> # Convert a single function
+            >>> converted_ir_mod = converter.convert("my_relax_func")
+            >>> # Convert multiple functions
+            >>> converted_ir_mod = converter.convert(["func1", "func2"])
+        """
+        if isinstance(relax_function_names, str):
+            relax_function_names = [relax_function_names]
+
+        # Create a copy of the current IRModule
+        new_ir_mod = self.ir_module.clone()
+
+        # Initialize pyfuncs if not exists
+        if not hasattr(new_ir_mod, "pyfuncs"):
+            new_ir_mod.pyfuncs = {}
+
+        # Get Relax function names from IRModule
+        relax_func_names = []
+        for global_var, func in self.ir_module.functions_items():
+            if isinstance(func, relax.Function):
+                relax_func_names.append(global_var.name_hint)
+
+        # Convert each Relax function
+        for func_name in relax_function_names:
+            if func_name not in relax_func_names:
+                raise ValueError(f"Relax function '{func_name}' not found in 
IRModule")
+
+            # Get the Relax function
+            relax_func = None
+            for global_var, func in self.ir_module.functions_items():
+                if global_var.name_hint == func_name and isinstance(func, relax.Function):
+                    relax_func = func
+                    break
+
+            if relax_func is None:
+                raise ValueError(f"Could not find Relax function 
'{func_name}'")
+
+            # Convert to Python function
+            py_func = self._convert_relax_func_to_python(relax_func, func_name)
+
+            # Store in pyfuncs
+            new_ir_mod.pyfuncs[func_name] = py_func
+
+        return new_ir_mod
+
+    def _convert_relax_func_to_python(self, relax_func: relax.Function, func_name: str) -> callable:
+        """Convert a single Relax function to a Python function with caching."""
+        # Get function parameters
+        params = relax_func.params
+
+        # Create the Python function
+        def converted_function(*args, **_kwargs):
+            """Converted Python function from Relax function."""
+            # Handle arguments
+            if len(args) != len(params):
+                raise ValueError(f"Expected {len(params)} arguments, got 
{len(args)}")
+
+            # Use cached converter or create new one
+            if func_name not in self._converter_cache:
+                self._converter_cache[func_name] = RelaxExpressionConverter(
+                    self.operator_map, self.ir_module, self._op_cache
+                )
+
+            # Execute the converted function body
+            converter = self._converter_cache[func_name]
+            converter.current_params = params
+            return converter.convert_expr(relax_func.body, args)
+
+        # Set function metadata
+        converted_function.__name__ = func_name
+        converted_function.__doc__ = f"Converted Python function from Relax function: {func_name}"
+
+        return converted_function
+
+    @staticmethod
+    def _get_op_map() -> Dict[str, str]:
+        """Get the mapping from Relax operators to PyTorch operators."""
+        return {
+            # Binary operations
+            "relax.add": "torch.add",
+            "relax.subtract": "torch.sub",
+            "relax.multiply": "torch.mul",
+            "relax.divide": "torch.div",
+            "relax.power": "torch.pow",
+            "relax.maximum": "torch.maximum",
+            "relax.minimum": "torch.minimum",
+            "relax.floor_divide": "torch.floor_divide",
+            "relax.mod": "torch.fmod",
+            "relax.floor_mod": "torch.remainder",
+            "relax.log_add_exp": "torch.logaddexp",
+            # Bitwise operations
+            "relax.bitwise_and": "torch.bitwise_and",
+            "relax.bitwise_or": "torch.bitwise_or",
+            "relax.bitwise_xor": "torch.bitwise_xor",
+            "relax.left_shift": "torch.left_shift",
+            "relax.right_shift": "torch.right_shift",
+            # Unary operations
+            "relax.abs": "torch.abs",
+            "relax.negative": "torch.neg",
+            "relax.exp": "torch.exp",
+            "relax.log": "torch.log",
+            "relax.sqrt": "torch.sqrt",
+            "relax.rsqrt": "torch.rsqrt",
+            "relax.sin": "torch.sin",
+            "relax.cos": "torch.cos",
+            "relax.tanh": "torch.tanh",
+            "relax.sigmoid": "torch.sigmoid",
+            "relax.square": "torch.square",
+            "relax.sign": "torch.sign",
+            "relax.floor": "torch.floor",
+            "relax.ceil": "torch.ceil",
+            "relax.round": "torch.round",
+            "relax.trunc": "torch.trunc",
+            "relax.clip": "torch.clamp",
+            "relax.bitwise_not": "torch.bitwise_not",
+            # Trigonometric functions
+            "relax.acos": "torch.acos",
+            "relax.asin": "torch.asin",
+            "relax.atan": "torch.atan",
+            "relax.cosh": "torch.cosh",
+            "relax.sinh": "torch.sinh",
+            "relax.tan": "torch.tan",
+            "relax.acosh": "torch.acosh",
+            "relax.asinh": "torch.asinh",
+            "relax.atanh": "torch.atanh",
+            # Special functions
+            "relax.erf": "torch.erf",
+            "relax.isfinite": "torch.isfinite",
+            "relax.isinf": "torch.isinf",
+            "relax.isnan": "torch.isnan",
+            # Neural network operations
+            "relax.nn.relu": "F.relu",
+            "relax.nn.relu6": "F.relu6",
+            "relax.nn.gelu": "F.gelu",
+            "relax.nn.gelu_tanh": "F.gelu",
+            "relax.nn.softmax": "F.softmax",
+            "relax.nn.log_softmax": "F.log_softmax",
+            "relax.nn.dropout": "F.dropout",
+            "relax.nn.batch_norm": "F.batch_norm",
+            "relax.nn.layer_norm": "F.layer_norm",
+            "relax.nn.group_norm": "F.group_norm",
+            "relax.nn.instance_norm": "F.instance_norm",
+            "relax.nn.rms_norm": "F.layer_norm",  # Approximate mapping
+            "relax.nn.linear": "F.linear",
+            "relax.nn.conv1d": "F.conv1d",
+            "relax.nn.conv2d": "F.conv2d",
+            "relax.nn.conv3d": "F.conv3d",
+            "relax.nn.conv1d_transpose": "F.conv_transpose1d",
+            "relax.nn.conv2d_transpose": "F.conv_transpose2d",
+            "relax.nn.conv3d_transpose": "F.conv_transpose3d",
+            "relax.nn.max_pool1d": "F.max_pool1d",
+            "relax.nn.max_pool2d": "F.max_pool2d",
+            "relax.nn.max_pool3d": "F.max_pool3d",
+            "relax.nn.avg_pool1d": "F.avg_pool1d",
+            "relax.nn.avg_pool2d": "F.avg_pool2d",
+            "relax.nn.avg_pool3d": "F.avg_pool3d",
+            "relax.nn.adaptive_avg_pool1d": "F.adaptive_avg_pool1d",
+            "relax.nn.adaptive_avg_pool2d": "F.adaptive_avg_pool2d",
+            "relax.nn.adaptive_avg_pool3d": "F.adaptive_avg_pool3d",
+            "relax.nn.leakyrelu": "F.leaky_relu",
+            "relax.nn.prelu": "F.prelu",
+            "relax.nn.selu": "F.selu",
+            "relax.nn.silu": "F.silu",
+            "relax.nn.softplus": "F.softplus",
+            "relax.nn.attention": "F.scaled_dot_product_attention",  # 
Approximate mapping
+            "relax.nn.cross_entropy_with_logits": "F.cross_entropy",
+            "relax.nn.nll_loss": "F.nll_loss",
+            "relax.nn.pad": "F.pad",
+            "relax.nn.pixel_shuffle": "F.pixel_shuffle",
+            # Tensor operations
+            "relax.matmul": "torch.matmul",
+            "relax.linear": "F.linear",
+            "relax.einsum": "torch.einsum",
+            "relax.outer": "torch.outer",
+            "relax.reshape": "reshape",  # Special handling needed
+            "relax.permute_dims": "permute_dims",  # Special handling needed
+            "relax.expand_dims": "expand_dims",  # Special handling needed
+            "relax.squeeze": "squeeze",  # Special handling needed
+            "relax.concat": "concat",  # Special handling needed
+            "relax.split": "split",  # Special handling needed
+            "relax.stack": "stack",  # Special handling needed
+            "relax.tile": "tile",  # Special handling needed
+            "relax.repeat": "repeat",  # Special handling needed
+            "relax.broadcast_to": "torch.broadcast_to",
+            "relax.flatten": "torch.flatten",
+            "relax.flip": "flip",  # Special handling needed
+            "relax.roll": "torch.roll",
+            "relax.rot90": "torch.rot90",
+            "relax.meshgrid": "torch.meshgrid",
+            "relax.one_hot": "F.one_hot",
+            "relax.layout_transform": "torch.permute",  # Approximate mapping
+            # Indexing operations
+            "relax.take": "take",  # Special handling needed
+            "relax.gather_elements": "torch.gather",
+            "relax.gather_nd": "torch.gather",
+            "relax.scatter_elements": "torch.scatter",
+            "relax.scatter_nd": "torch.scatter",
+            "relax.index_put": "torch.index_put",
+            "relax.index_tensor": "torch.index_select",
+            "relax.strided_slice": "torch.slice",
+            "relax.dynamic_strided_slice": "torch.slice",
+            "relax.slice_scatter": "torch.scatter",
+            # Reduction operations
+            "relax.sum": "sum",  # Special handling needed
+            "relax.mean": "mean",  # Special handling needed
+            "relax.max": "max",  # Special handling needed
+            "relax.min": "min",  # Special handling needed
+            "relax.prod": "torch.prod",
+            "relax.std": "std",  # Special handling needed
+            "relax.variance": "variance",  # Special handling needed
+            "relax.cumsum": "torch.cumsum",
+            "relax.cumprod": "torch.cumprod",
+            "relax.argmax": "torch.argmax",
+            "relax.argmin": "torch.argmin",
+            # Comparison operations
+            "relax.equal": "torch.eq",
+            "relax.not_equal": "torch.ne",
+            "relax.greater": "torch.gt",
+            "relax.greater_equal": "torch.ge",
+            "relax.less": "torch.lt",
+            "relax.less_equal": "torch.le",
+            # Logical operations
+            "relax.logical_and": "torch.logical_and",
+            "relax.logical_or": "torch.logical_or",
+            "relax.logical_not": "torch.logical_not",
+            "relax.logical_xor": "torch.logical_xor",
+            # Creation operations
+            "relax.zeros": "torch.zeros",
+            "relax.ones": "torch.ones",
+            "relax.full": "torch.full",
+            "relax.full_like": "torch.full_like",
+            "relax.zeros_like": "torch.zeros_like",
+            "relax.ones_like": "torch.ones_like",
+            "relax.arange": "torch.arange",
+            "relax.eye": "torch.eye",
+            "relax.eye_like": "torch.eye",
+            "relax.tril": "torch.tril",
+            "relax.triu": "torch.triu",
+            "relax.hamming_window": "torch.hamming_window",
+            # Search operations
+            "relax.where": "torch.where",
+            "relax.bucketize": "torch.bucketize",
+            "relax.nonzero": "torch.nonzero",
+            "relax.unique": "torch.unique",
+            # Sorting operations
+            "relax.sort": "torch.sort",
+            "relax.argsort": "torch.argsort",
+            "relax.topk": "torch.topk",
+            # Sampling operations
+            "relax.multinomial_from_uniform": "torch.multinomial",
+            # Ternary operations
+            "relax.ewise_fma": "torch.fma",  # Approximate mapping
+            # Data type operations
+            "relax.astype": "torch.to",
+            "relax.wrap_param": "torch.tensor",
+            # Mask operations
+            "relax.masked_fill": "torch.masked_fill",
+            # Quantization operations
+            "relax.quantize": "torch.quantize_per_tensor",  # Approximate 
mapping
+            "relax.dequantize": "torch.dequantize",  # Approximate mapping
+            # Special operations (handled separately)
+            "relax.call_tir": "call_tir",
+            "relax.call_tir_inplace": "call_tir_inplace",
+            "relax.call_dps_packed": "call_dps_packed",
+            "relax.call_pure_packed": "call_pure_packed",
+            "relax.call_tir_with_grad": "call_tir_with_grad",
+            "relax.call_builtin_with_ctx": "call_builtin_with_ctx",
+            "relax.call_inplace_packed": "call_inplace_packed",
+            "relax.invoke_closure": "invoke_closure",
+            "relax.invoke_pure_closure": "invoke_pure_closure",
+            "relax.make_closure": "make_closure",
+            "relax.null_value": "null_value",
+            "relax.print": "print",
+            "relax.shape_of": "shape_of",
+            "relax.shape_to_tensor": "shape_to_tensor",
+            "relax.tensor_to_shape": "tensor_to_shape",
+            "relax.to_vdevice": "to_vdevice",
+            "relax.hint_on_device": "hint_on_device",
+            "relax.assert_op": "assert_op",
+        }
+
+
+class RelaxExpressionConverter:
+    """Converter that transforms Relax expressions to Python/PyTorch code."""
+
+    def __init__(
+        self,
+        operator_map: Dict[str, str],
+        ir_module: IRModule = None,
+        op_cache: Dict[str, str] = None,
+    ):
+        """Initialize the expression converter.
+
+        Args:
+            operator_map: Mapping from Relax operators to PyTorch operators
+            ir_module: The IRModule containing TIR functions to compile
+            op_cache: Shared cache for operator mappings to avoid repeated lookups
+        """
+        self.operator_map = operator_map
+        self.variable_map: Dict[str, Any] = {}
+        self.current_params: List[relax.Var] = []
+        self.ir_module = ir_module
+        # Use shared operator cache or create new one
+        self._op_cache = op_cache if op_cache is not None else {}
+
+    def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any:
+        """Convert a Relax expression to Python/PyTorch equivalent."""
+        if isinstance(expr, relax.Var):
+            return self._convert_var(expr, args)
+        elif isinstance(expr, relax.Call):
+            return self._convert_call(expr, args)
+        elif isinstance(expr, relax.Constant):
+            return self._convert_constant(expr)
+        elif isinstance(expr, relax.SeqExpr):
+            return self._convert_seq_expr(expr, args)
+        elif isinstance(expr, relax.Tuple):
+            return self._convert_tuple(expr, args)
+        elif isinstance(expr, relax.TupleGetItem):
+            return self._convert_tuple_get_item(expr, args)
+        elif isinstance(expr, relax.If):
+            return self._convert_if(expr, args)
+        elif isinstance(expr, relax.ShapeExpr):
+            return self._convert_shape_expr(expr)
+        else:
+            # Fallback for unknown expression types
+            return f"<unknown_expr: {type(expr).__name__}>"
+
+    def _convert_var(self, var: relax.Var, args: List[Any]) -> Any:
+        """Convert a Relax variable to Python equivalent."""
+        if hasattr(var, "name_hint"):
+            var_name = var.name_hint
+
+            # Check if it's a function parameter
+            for i, param in enumerate(self.current_params):
+                if hasattr(param, "name_hint") and param.name_hint == var_name:
+                    return args[i]
+
+            # Check if it's a bound variable
+            if var_name in self.variable_map:
+                return self.variable_map[var_name]
+
+            # Return placeholder for unbound variables
+            return f"<unbound_var: {var_name}>"
+        return f"<var: {var}>"
+
+    def _convert_call(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert a Relax call to Python/PyTorch equivalent."""
+        op = call.op
+
+        # Handle different types of calls
+        if isinstance(op, relax.GlobalVar):
+            # Function call
+            return self._convert_function_call(call, args)
+        elif isinstance(op, Op):
+            # Operator call
+            return self._convert_operator_call(call, args)
+        elif isinstance(op, relax.ExternFunc):
+            # External function call (like call_tir, call_dps_packed)
+            return self._convert_extern_func_call(call, args)
+        else:
+            return f"<call: {type(op).__name__}>"
+
+    def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert a Relax function call."""
+        func_name = call.op.name_hint
+        call_args = [self.convert_expr(arg, args) for arg in call.args]
+
+        # Handle special cases
+        if func_name in ["call_tir", "call_tir_inplace"]:
+            return self._convert_call_tir(call, args)
+        elif func_name in ["call_dps_packed", "call_pure_packed"]:
+            return self._convert_call_dps_packed(call, args)
+        else:
+            # Regular function call
+            return f"<func_call: {func_name}({', '.join(map(str, 
call_args))})>"
+
+    def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert a Relax operator call to PyTorch equivalent."""
+        op_name = call.op.name
+        call_args = [self.convert_expr(arg, args) for arg in call.args]
+
+        # Use cached operator mapping or look it up
+        if op_name not in self._op_cache:
+            self._op_cache[op_name] = self.operator_map.get(op_name)
+        pytorch_op = self._op_cache[op_name]
+        if pytorch_op:
+            try:
+                # Handle special operations
+                if pytorch_op == "call_tir":
+                    return self._convert_call_tir(call, args)
+                elif pytorch_op == "call_tir_inplace":
+                    return self._convert_call_tir(call, args)
+                elif pytorch_op == "call_dps_packed":
+                    return self._convert_call_dps_packed(call, args)
+                elif pytorch_op == "call_pure_packed":
+                    return self._convert_call_dps_packed(call, args)
+                elif pytorch_op == "expand_dims":
+                    return self._convert_expand_dims(call, args)
+                elif pytorch_op in ["sum", "mean", "max", "min", "std", 
"variance"]:
+                    return self._convert_reduction_op(call, args, pytorch_op)
+                elif pytorch_op == "squeeze":
+                    return self._convert_squeeze(call, args)
+                elif pytorch_op in ["concat", "split", "stack"]:
+                    return self._convert_tensor_ops(call, args, pytorch_op)
+                elif pytorch_op == "reshape":
+                    return self._convert_reshape(call, args)
+                elif pytorch_op == "permute_dims":
+                    return self._convert_permute_dims(call, args)
+                elif pytorch_op == "take":
+                    return self._convert_take(call, args)
+                elif pytorch_op == "flip":
+                    return self._convert_flip(call, args)
+                elif pytorch_op == "tile":
+                    return self._convert_tile(call, args)
+                elif pytorch_op == "repeat":
+                    return self._convert_repeat(call, args)
+                # Handle special cases for PyTorch operations
+                elif pytorch_op.startswith("F."):
+                    return self._handle_functional_operation(pytorch_op, call, call_args)
+                elif pytorch_op.startswith("torch."):
+                    # Regular PyTorch operation
+                    func_name = pytorch_op[6:]  # Remove "torch." prefix
+                    func = getattr(torch, func_name)
+                    return func(*call_args)
+                else:
+                    # Direct function reference - use getattr for safer access
+                    if pytorch_op.startswith("torch."):
+                        module = torch
+                        func_name = pytorch_op[6:]  # Remove "torch." prefix
+                    elif pytorch_op.startswith("F."):
+                        module = F
+                        func_name = pytorch_op[2:]  # Remove "F." prefix
+                    else:
+                        return (
+                            f"<exec_error: {pytorch_op}({', '.join(map(str, 
call_args))}) "
+                            f"- unsupported operation>"
+                        )
+
+                    func = getattr(module, func_name, None)
+                    if func is None:
+                        return (
+                            f"<exec_error: {pytorch_op}({', '.join(map(str, 
call_args))}) "
+                            f"- function not found>"
+                        )
+                    return func(*call_args)
+            except (AttributeError, TypeError, ValueError) as error:
+                # This allows the test framework to catch and handle the errors appropriately
+                if pytorch_op.startswith("torch.") or pytorch_op.startswith("F."):
+                    raise error
+                # Fallback to string representation for non-PyTorch operations
+                return f"<exec_error: {pytorch_op}({', '.join(map(str, 
call_args))}) - {error}>"
+        else:
+            # Unknown operator
+            return f"<unknown_op: {op_name}({', '.join(map(str, call_args))})>"
+
+    def _handle_functional_operation(
+        self, pytorch_op: str, call: relax.Call, call_args: List[Any]
+    ) -> Any:
+        """Handle PyTorch functional operations with special parameter 
handling."""
+        # Neural network function
+        func_name = pytorch_op[2:]  # Remove "F." prefix
+        func = getattr(F, func_name)
+
+        # Special handling for functions that need dim parameter
+        if func_name in ["softmax", "log_softmax"]:
+            # Extract axis from call.attrs and convert to dim
+            axis = None
+            if call.attrs and hasattr(call.attrs, "axis"):
+                axis = call.attrs.axis
+                if hasattr(axis, "value"):
+                    axis = int(axis.value)
+                elif isinstance(axis, (int, float)):
+                    axis = int(axis)
+
+            if axis is not None:
+                return func(call_args[0], dim=axis)
+            else:
+                # Default to last dimension if no axis specified
+                return func(call_args[0], dim=-1)
+        else:
+            return func(*call_args)
+
+    def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert an external function call."""
+        func_name = call.op.global_symbol
+        call_args = [self.convert_expr(arg, args) for arg in call.args]
+
+        if func_name in ["call_tir", "call_tir_inplace"]:
+            return self._convert_call_tir(call, args)
+        elif func_name in ["call_dps_packed", "call_pure_packed"]:
+            return self._convert_call_dps_packed(call, args)
+        else:
+            return f"<extern_func: {func_name}({', '.join(map(str, 
call_args))})>"
+
+    def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert call_tir to Python equivalent with DLPack conversion."""
+        # Extract TIR function name and arguments
+        tir_func = call.args[0]
+        tir_args = call.args[1] if len(call.args) > 1 else []
+        out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None
+
+        # Get function name
+        if isinstance(tir_func, relax.GlobalVar):
+            func_name = tir_func.name_hint
+        else:
+            # Convert the GlobalVar expression
+            func_name = self.convert_expr(tir_func, args)
+            if isinstance(func_name, str) and func_name.startswith("<"):
+                # If it's a placeholder, extract the name
+                func_name = str(tir_func)
+
+        # Convert arguments to PyTorch tensors
+        converted_args = [self.convert_expr(arg, args) for arg in tir_args]
+
+        try:
+            # First, try to get the TIR function from the current IRModule
+            tir_function = None
+            if self.ir_module:
+                # Look for the TIR function in the current IRModule
+                for global_var, func in self.ir_module.functions.items():
+                    if global_var.name_hint == func_name and hasattr(func, "body"):
+                        try:
+                            # Compile the TIR function
+                            target = tvm.target.Target("llvm")
+                            with tvm.target.Target(target):
+                                tir_function = tvm.compile(func, target=target)
+                            break
+                        except (RuntimeError, ValueError, TypeError) as compile_e:
+                            print(
+                                f"Warning: Failed to compile TIR function {func_name}: {compile_e}"
+                            )
+                            continue
+
+            # If not found in current module, try global registry
+            if tir_function is None:
+                tir_function = tvm.get_global_func(func_name)
+
+            if tir_function is None:
+                return (
+                    f"<call_tir_error: {func_name} - Cannot find or compile 
function {func_name}>"
+                )
+
+            # Convert PyTorch tensors to TVM NDArrays via DLPack
+            tvm_args = []
+            for arg in converted_args:
+                if isinstance(arg, torch.Tensor):
+                    # Convert PyTorch tensor to TVM NDArray via DLPack
+                    tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg))
+                    tvm_args.append(tvm_arg)
+                else:
+                    tvm_args.append(arg)
+
+            # For call_tir, we need to allocate output tensor
+            output_shape = None
+            if out_sinfo and hasattr(out_sinfo, "shape"):
+                output_shape = out_sinfo.shape
+            elif converted_args:
+                # Use the shape of the first input tensor
+                first_arg = converted_args[0]
+                if isinstance(first_arg, torch.Tensor):
+                    output_shape = first_arg.shape
+
+            if output_shape is None:
+                return f"<call_tir_error: {func_name} - Cannot determine 
output shape>"
+
+            # Allocate output tensor
+            output_tensor = tvm.nd.array(tvm.nd.empty(output_shape, dtype="float32"))
+            tvm_args.append(output_tensor)
+
+            # Call the TIR function
+            tir_function(*tvm_args)
+
+            # The result is in the output_tensor we allocated
+            # Convert result back to PyTorch tensor via DLPack
+            return torch.from_dlpack(output_tensor.to_dlpack())
+
+        except (RuntimeError, ValueError, TypeError) as error:
+            return f"<call_tir_error: {func_name} - {error}>"
+
+    def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert call_dps_packed to Python equivalent with DLPack conversion."""
+        # Extract packed function name and arguments
+        packed_func = call.args[0]
+        packed_args = call.args[1] if len(call.args) > 1 else []
+        _out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None
+
+        # Get function name
+        if isinstance(packed_func, relax.GlobalVar):
+            func_name = packed_func.name_hint
+        elif isinstance(packed_func, relax.ExternFunc):
+            func_name = packed_func.global_symbol
+        else:
+            func_name = str(packed_func)
+
+        # Convert arguments to PyTorch tensors
+        converted_args = [self.convert_expr(arg, args) for arg in packed_args]
+
+        try:
+            # Get the packed function from TVM
+            packed_function = tvm.get_global_func(func_name)
+            if packed_function is None:
+                return f"<call_dps_packed_error: Function {func_name} not 
found>"
+
+            # Convert PyTorch tensors to TVM NDArrays via DLPack
+            tvm_args = []
+            for arg in converted_args:
+                if isinstance(arg, torch.Tensor):
+                    # Convert PyTorch tensor to TVM NDArray via DLPack
+                    tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg))
+                    tvm_args.append(tvm_arg)
+                else:
+                    tvm_args.append(arg)
+
+            # Call the packed function
+            result = packed_function(*tvm_args)
+
+            # Convert result back to PyTorch tensor via DLPack
+            if isinstance(result, tvm.nd.NDArray):
+                return torch.from_dlpack(result.to_dlpack())
+            else:
+                return result
+
+        except (RuntimeError, ValueError, TypeError) as error:
+            return f"<call_dps_packed_error: {func_name} - {error}>"
+
+    def _convert_constant(self, const: relax.Constant) -> Any:
+        """Convert a Relax constant to Python equivalent."""
+        if hasattr(const, "data"):
+            data = const.data
+            # Convert TVM NDArray to Python scalar if it's a scalar
+            if hasattr(data, "numpy"):
+                numpy_data = data.numpy()
+                if numpy_data.size == 1:
+                    return float(numpy_data.item())
+                else:
+                    # For multi-element arrays, convert to PyTorch tensor
+                    return torch.from_numpy(numpy_data)
+            elif hasattr(data, "item"):
+                # Single element tensor
+                return data.item()
+            else:
+                return data
+        return f"<const: {const}>"
+
+    def _convert_seq_expr(self, seq: relax.SeqExpr, args: List[Any]) -> Any:
+        """Convert a Relax sequence expression."""
+        # Convert blocks
+        for block in seq.blocks:
+            if hasattr(block, "bindings"):
+                for binding in block.bindings:
+                    if isinstance(binding, relax.VarBinding):
+                        var_name = binding.var.name_hint
+                        value = self.convert_expr(binding.value, args)
+                        self.variable_map[var_name] = value
+
+        # Convert body
+        return self.convert_expr(seq.body, args)
+
+    def _convert_tuple(self, tuple_expr: relax.Tuple, args: List[Any]) -> Any:
+        """Convert a Relax tuple to Python tuple."""
+        elements = [self.convert_expr(elem, args) for elem in tuple_expr.fields]
+        return tuple(elements)
+
+    def _convert_tuple_get_item(self, get_item: relax.TupleGetItem, args: List[Any]) -> Any:
+        """Convert a Relax tuple get item to Python equivalent."""
+        tuple_expr = self.convert_expr(get_item.tuple_value, args)
+        index = get_item.index
+        return f"<tuple_get_item: {tuple_expr}[{index}]>"
+
+    def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any:
+        """Convert a Relax if expression to Python equivalent."""
+        condition = self.convert_expr(if_expr.cond, args)
+        true_branch = self.convert_expr(if_expr.true_branch, args)
+        false_branch = self.convert_expr(if_expr.false_branch, args)
+        return f"<if: {condition} ? {true_branch} : {false_branch}>"
+
+    def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert expand_dims to torch.unsqueeze with proper axis handling."""
+        if len(call.args) < 1:
+            return "<expand_dims_error: insufficient arguments>"
+
+        # Convert the tensor argument
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Get the axis from call.attrs
+        axis = None
+        if call.attrs and hasattr(call.attrs, "axis"):
+            axis = call.attrs.axis
+            # Handle different types of axis
+            if hasattr(axis, "__iter__") and not isinstance(axis, str):
+                # It's an array/list, take the first element
+                axis = list(axis)[0] if len(axis) > 0 else None
+
+            # Handle TVM types
+            if hasattr(axis, "value"):
+                # It's a TVM IntImm or similar, get the value
+                axis = int(axis.value)
+            elif isinstance(axis, (int, float)):
+                axis = int(axis)
+
+        if axis is None:
+            return "<expand_dims_error: cannot determine axis>"
+
+        # Use torch.unsqueeze with the correct axis
+        return torch.unsqueeze(tensor_arg, dim=axis)
+
+    def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) -> Any:
+        """Convert reduction operations with axis and keepdims parameters."""
+        if len(call.args) < 1:
+            return f"<{op_name}_error: insufficient arguments>"
+
+        # Convert the tensor argument
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Get axis and keepdims from call.attrs
+        axis = None
+        keepdims = False
+
+        if call.attrs:
+            if hasattr(call.attrs, "axis") and call.attrs.axis is not None:
+                axis = call.attrs.axis
+                # Handle different types of axis
+                if hasattr(axis, "__iter__") and not isinstance(axis, str):
+                    # It's an array/list, convert to list of ints
+                    axis = [
+                        int(item.value) if hasattr(item, "value") else int(item) for item in axis
+                    ]
+                elif hasattr(axis, "value"):
+                    # It's a TVM IntImm, get the value
+                    axis = int(axis.value)
+                elif isinstance(axis, (int, float)):
+                    axis = int(axis)
+
+            if hasattr(call.attrs, "keepdims"):
+                keepdims = bool(call.attrs.keepdims)
+
+        # Get the PyTorch function
+        func = getattr(torch, op_name)
+
+        # Call with appropriate parameters
+        if axis is not None:
+            # For max and min, PyTorch returns (values, indices) tuple when dim is specified
+            if op_name in ["max", "min"]:
+                if isinstance(axis, list) and len(axis) == 1:
+                    axis = axis[0]
+                elif isinstance(axis, list) and len(axis) > 1:
+                    axis = axis[0]
+                result = func(tensor_arg, axis, keepdim=keepdims)
+                if isinstance(result, tuple):
+                    return result[0]
+                else:
+                    return result
+            else:
+                return func(tensor_arg, dim=axis, keepdim=keepdims)
+        else:
+            return func(tensor_arg)
+
+    def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert squeeze to torch.squeeze with proper axis handling."""
+        if len(call.args) < 1:
+            return "<squeeze_error: insufficient arguments>"
+
+        # Convert the tensor argument
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Get axis from call.attrs
+        axis = None
+        if call.attrs and hasattr(call.attrs, "axis") and call.attrs.axis is not None:
+            axis = call.attrs.axis
+            # Handle different types of axis
+            if hasattr(axis, "__iter__") and not isinstance(axis, str):
+                axis = [int(item.value) if hasattr(item, "value") else int(item) for item in axis]
+            elif hasattr(axis, "value"):
+                axis = int(axis.value)
+            elif isinstance(axis, (int, float)):
+                axis = int(axis)
+
+        # Call torch.squeeze with appropriate parameters
+        if axis is not None:
+            return torch.squeeze(tensor_arg, dim=axis)
+        else:
+            return torch.squeeze(tensor_arg)
+
+    def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) -> Any:
+        """Convert tensor operations like concat, split, stack."""
+        if len(call.args) < 1:
+            return f"<{op_name}_error: insufficient arguments>"
+
+        # Convert arguments
+        converted_args = [self.convert_expr(arg, args) for arg in call.args]
+
+        if op_name == "concat":
+            # torch.cat(tensors, dim=0)
+            # In Relax, concat takes a tuple of tensors as first argument
+            if len(converted_args) == 1 and isinstance(converted_args[0], tuple):
+                # This is a tuple of tensors
+                tensors = converted_args[0]
+            else:
+                # Direct tensor arguments
+                tensors = converted_args
+            axis = 0
+            if call.attrs and hasattr(call.attrs, "axis"):
+                axis = call.attrs.axis
+                if hasattr(axis, "value"):
+                    axis = int(axis.value)
+                elif isinstance(axis, (int, float)):
+                    axis = int(axis)
+            return torch.cat(tensors, dim=axis)
+
+        elif op_name == "split":
+            # torch.split(tensor, split_size_or_sections, dim=0)
+            tensor = converted_args[0]
+            split_size = converted_args[1] if len(converted_args) > 1 else 1
+            axis = 0
+            if call.attrs and hasattr(call.attrs, "axis"):
+                axis = call.attrs.axis
+                if hasattr(axis, "value"):
+                    axis = int(axis.value)
+                elif isinstance(axis, (int, float)):
+                    axis = int(axis)
+
+            # Handle indices_or_sections parameter
+            if call.attrs and hasattr(call.attrs, "indices_or_sections"):
+                indices_or_sections = call.attrs.indices_or_sections
+                if hasattr(indices_or_sections, "value"):
+                    indices_or_sections = int(indices_or_sections.value)
+                elif isinstance(indices_or_sections, (int, float)):
+                    indices_or_sections = int(indices_or_sections)
+
+                # If indices_or_sections is an integer, it means split into N equal parts
+                if isinstance(indices_or_sections, int):
+                    total_size = tensor.shape[axis]
+                    split_size = total_size // indices_or_sections
+                    return torch.split(tensor, split_size, dim=axis)
+                else:
+                    # If it's a list, use it directly
+                    return torch.split(tensor, indices_or_sections, dim=axis)
+            else:
+                return torch.split(tensor, split_size, dim=axis)
+
+        elif op_name == "stack":
+            # torch.stack(tensors, dim=0)
+            if len(converted_args) == 1 and isinstance(converted_args[0], tuple):
+                tensors = converted_args[0]
+            else:
+                tensors = converted_args
+            axis = 0
+            if call.attrs and hasattr(call.attrs, "axis"):
+                axis = call.attrs.axis
+                if hasattr(axis, "value"):
+                    axis = int(axis.value)
+                elif isinstance(axis, (int, float)):
+                    axis = int(axis)
+            return torch.stack(tensors, dim=axis)
+
+        else:
+            return f"<{op_name}_error: unsupported operation>"
+
+    def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert reshape operation."""
+        if len(call.args) < 2:
+            return "<reshape_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+        shape_arg = call.args[1]
+
+        # Convert shape argument to Python tuple
+        if isinstance(shape_arg, relax.ShapeExpr):
+            if hasattr(shape_arg, "values"):
+                shape = tuple(
+                    int(v.value) if hasattr(v, "value") else int(v) for v in shape_arg.values
+                )
+            else:
+                shape = (int(shape_arg),)
+        elif isinstance(shape_arg, relax.Constant):
+            # Constant tensor case
+            shape_data = shape_arg.data.numpy()
+            shape = tuple(int(v) for v in shape_data)
+        else:
+            # Try to convert as expression
+            converted_shape = self.convert_expr(shape_arg, args)
+            if isinstance(converted_shape, (list, tuple)):
+                shape = tuple(int(v) for v in converted_shape)
+            else:
+                shape = (int(converted_shape),)
+
+        return torch.reshape(tensor_arg, shape)
+
+    def _convert_permute_dims(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert permute_dims operation."""
+        if len(call.args) < 1:
+            return "<permute_dims_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Extract axes from call.attrs
+        if call.attrs and hasattr(call.attrs, "axes"):
+            axes = call.attrs.axes
+            # Handle TVM Array type
+            if hasattr(axes, "__iter__") and not isinstance(axes, str):
+                # Convert TVM Array or Python list/tuple to tuple of ints
+                axes = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in axes)
+            elif isinstance(axes, (list, tuple)):
+                axes = tuple(int(v) for v in axes)
+            else:
+                axes = (int(axes),)
+        else:
+            return "<permute_dims_error: no axes attribute>"
+
+        return torch.permute(tensor_arg, axes)
+
+    def _convert_take(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert take operation."""
+        if len(call.args) < 2:
+            return "<take_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+        indices_arg = self.convert_expr(call.args[1], args)
+
+        # Extract axis from call.attrs
+        axis = None
+        if call.attrs and hasattr(call.attrs, "axis"):
+            axis = call.attrs.axis
+            if hasattr(axis, "value"):
+                axis = int(axis.value)
+            elif isinstance(axis, (int, float)):
+                axis = int(axis)
+
+        if axis is not None:
+            # Use advanced indexing for specific axis
+            if axis == 0:
+                return tensor_arg[indices_arg]
+            else:
+                # For other axes, we need to use torch.index_select
+                return torch.index_select(tensor_arg, dim=axis, index=indices_arg)
+        else:
+            # No axis specified, use torch.take (flattens the tensor)
+            return torch.take(tensor_arg, indices_arg)
+
+    def _convert_flip(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert flip operation."""
+        if len(call.args) < 1:
+            return "<flip_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Extract axis from call.attrs
+        axis = None
+        if call.attrs and hasattr(call.attrs, "axis"):
+            axis = call.attrs.axis
+            if hasattr(axis, "value"):
+                axis = int(axis.value)
+            elif isinstance(axis, (int, float)):
+                axis = int(axis)
+
+        if axis is not None:
+            # Convert single axis to list for torch.flip
+            dims = [axis]
+        else:
+            # Default: flip all dimensions
+            dims = list(range(tensor_arg.dim()))
+
+        return torch.flip(tensor_arg, dims=dims)
+
+    def _convert_tile(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert tile operation."""
+        if len(call.args) < 1:
+            return "<tile_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Extract repeats from call.attrs
+        if call.attrs and hasattr(call.attrs, "repeats"):
+            repeats = call.attrs.repeats
+            # Handle TVM Array type
+            if hasattr(repeats, "__iter__") and not isinstance(repeats, str):
+                repeats = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in repeats)
+            elif isinstance(repeats, (list, tuple)):
+                repeats = tuple(int(v) for v in repeats)
+            else:
+                repeats = (int(repeats),)
+        else:
+            return "<tile_error: no repeats attribute>"
+
+        return torch.tile(tensor_arg, dims=repeats)
+
+    def _convert_repeat(self, call: relax.Call, args: List[Any]) -> Any:
+        """Convert repeat operation."""
+        if len(call.args) < 1:
+            return "<repeat_error: insufficient arguments>"
+
+        tensor_arg = self.convert_expr(call.args[0], args)
+
+        # Extract repeats and axis from call.attrs
+        repeats = 1
+        axis = None
+
+        if call.attrs and hasattr(call.attrs, "repeats"):
+            repeats = call.attrs.repeats
+            if hasattr(repeats, "value"):
+                repeats = int(repeats.value)
+            elif isinstance(repeats, (int, float)):
+                repeats = int(repeats)
+
+        if call.attrs and hasattr(call.attrs, "axis"):
+            axis = call.attrs.axis
+            if hasattr(axis, "value"):
+                axis = int(axis.value)
+            elif isinstance(axis, (int, float)):
+                axis = int(axis)
+
+        if axis is not None:
+            return torch.repeat_interleave(tensor_arg, repeats=repeats, dim=axis)
+        else:
+            return torch.repeat_interleave(tensor_arg, repeats=repeats)
+
+    def _convert_shape_expr(self, shape_expr: relax.ShapeExpr) -> Any:
+        """Convert a Relax shape expression to Python equivalent."""
+        if hasattr(shape_expr, "values"):
+            return f"<shape: ({', '.join(map(str, shape_expr.values))})>"
+        return f"<shape: {shape_expr}>"
+
+
+def convert_relax_to_pyfunc(
+    ir_module: IRModule, relax_function_names: Union[str, List[str]]
+) -> IRModule:
+    """Convert Relax functions to Python functions.
+
+    Args:
+        ir_module: The IRModule containing Relax functions
+        relax_function_names: Name(s) of Relax functions to convert
+
+    Returns:
+        IRModule with converted Python functions stored in pyfuncs
+
+    Example:
+        >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "my_function")
+        >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, ["func1", "func2"])
+    """
+    converter = RelaxToPyFuncConverter(ir_module)
+    return converter.convert(relax_function_names)
diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py
new file mode 100644
index 0000000000..6dce309315
--- /dev/null
+++ b/tests/python/relax/test_relax_to_pyfunc_converter.py
@@ -0,0 +1,866 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Comprehensive test cases for Relax to PyFunc converter.
+Tests all major features including basic operations, call_tir, call_dps_packed, and symbolic shapes.
+"""
+
+
+import pytest
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+import tvm
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
+
+
[email protected]_module
+class ComprehensiveTestModule:
+    """Test module covering all converter features."""
+
+    @T.prim_func
+    def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
+        """TIR function for addition."""
+        x = T.match_buffer(var_x, (5,), "float32")
+        y = T.match_buffer(var_y, (5,), "float32")
+        out = T.match_buffer(var_out, (5,), "float32")
+        for i in range(5):
+            out[i] = x[i] + y[i]
+
+    @T.prim_func
+    def mul_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
+        """TIR function for multiplication."""
+        x = T.match_buffer(var_x, (3, 4), "float32")
+        y = T.match_buffer(var_y, (3, 4), "float32")
+        out = T.match_buffer(var_out, (3, 4), "float32")
+        for i in range(3):
+            for j in range(4):
+                out[i, j] = x[i, j] * y[i, j]
+
+    @R.function
+    def simple_add(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.add(x, y)
+
+    @R.function
+    def with_relu(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.nn.relu(x)
+
+    @R.function
+    def with_call_tir(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        cls = ComprehensiveTestModule
+        return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32"))
+
+    @R.function
+    def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.call_dps_packed(
+            "my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), 
"float32")
+        )
+
+    @R.function
+    def complex_function(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        added = R.add(x, y)
+        relued = R.nn.relu(added)
+        cls = ComprehensiveTestModule
+        tir_result = R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32"))
+        return R.nn.relu(tir_result)
+
+    @R.function
+    def symbolic_add(
+        x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32")
+    ) -> R.Tensor(("n",), "float32"):
+        return R.add(x, y)
+
+    @R.function
+    def symbolic_matmul(
+        x: R.Tensor(("batch", "m", "k"), "float32"), y: R.Tensor(("batch", 
"k", "n"), "float32")
+    ) -> R.Tensor(("batch", "m", "n"), "float32"):
+        return R.matmul(x, y)
+
+    @R.function
+    def symbolic_expand_dims(
+        x: R.Tensor(("batch", "seq_len"), "float32")
+    ) -> R.Tensor(("batch", "seq_len", 1), "float32"):
+        return R.expand_dims(x, axis=2)
+
+    @R.function
+    def multi_ops(
+        x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")
+    ) -> R.Tensor((3, 4), "float32"):
+        added = R.add(x, y)
+        multiplied = R.multiply(added, y)
+        powered = R.power(multiplied, R.const(2.0))
+        maxed = R.maximum(powered, x)
+        return maxed
+
+    @R.function
+    def reduction_ops(x: R.Tensor((5,), "float32")) -> R.Tensor((), "float32"):
+        sum_val = R.sum(x)
+        mean_val = R.mean(x)
+        max_val = R.max(x)
+        return R.add(R.add(sum_val, mean_val), max_val)
+
+    @R.function
+    def comparison_ops(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "bool"):
+        eq_val = R.equal(x, y)
+        gt_val = R.greater(x, y)
+        return R.logical_and(eq_val, gt_val)
+
+    @R.function
+    def test_reshape(x: R.Tensor((2, 3), "float32")) -> R.Tensor((6,), "float32"):
+        return R.reshape(x, (6,))
+
+    @R.function
+    def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 3), "float32"):
+        return R.permute_dims(x, axes=[2, 0, 1])
+
+    @R.function
+    def test_concat(
+        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+    ) -> R.Tensor((4, 3), "float32"):
+        return R.concat((x, y), axis=0)
+
+    @R.function
+    def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple:
+        return R.split(x, indices_or_sections=2, axis=0)
+
+    @R.function
+    def test_stack(
+        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
+    ) -> R.Tensor((2, 2, 3), "float32"):
+        return R.stack((x, y), axis=1)
+
+    @R.function
+    def test_take(
+        x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64")
+    ) -> R.Tensor((2, 4), "float32"):
+        return R.take(x, indices, axis=0)
+
+    @R.function
+    def test_flip(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+        return R.flip(x, axis=1)
+
+    @R.function
+    def test_tile(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 6), "float32"):
+        return R.tile(x, (2, 2))
+
+    @R.function
+    def test_repeat(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"):
+        return R.repeat(x, repeats=2, axis=0)
+
+    @R.function
+    def test_expand_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3, 1), "float32"):
+        return R.expand_dims(x, axis=2)
+
+    @R.function
+    def test_squeeze(x: R.Tensor((2, 3, 1), "float32")) -> R.Tensor((2, 3), "float32"):
+        return R.squeeze(x, axis=2)
+
+    @R.function
+    def test_sum_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
+        return R.sum(x, axis=0)
+
+    @R.function
+    def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
+        return R.max(x, axis=0)
+
+
+def create_mock_packed_function():
+    """Create a mock packed function for testing."""
+
+    def mock_softmax(x, axis):
+        """Mock softmax function that just returns the input."""
+        return x
+
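+    # with_call_dps_packed above refers to its callee only by the string name
+    # "my_softmax", so the converted Python function can only execute if a
+    # packed function with that exact name is registered globally (presumably
+    # looked up via tvm.get_global_func at call time); setup_class does this
+    # before the tests run.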
+    # Register the function globally
+    tvm.register_func("my_softmax", mock_softmax)
+
+
+class TestRelaxToPyFuncConverter:
+    """Comprehensive test class for Relax to PyFunc converter."""
+
+    @classmethod
+    def setup_class(cls):
+        """Set up test fixtures."""
+        cls.ir_mod = ComprehensiveTestModule
+        cls.converter = RelaxToPyFuncConverter(cls.ir_mod)
+        create_mock_packed_function()
+
+    def test_basic_operations(self):
+        """Test basic arithmetic operations."""
+        converted_ir_mod = self.converter.convert(["simple_add", "with_relu"])
+
+        # Test simple_add
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["simple_add"](x, y)
+        expected = torch.add(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test with_relu
+        x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32)
+        result = converted_ir_mod.pyfuncs["with_relu"](x_neg)
+        expected = torch.nn.functional.relu(x_neg)
+        assert torch.allclose(result, expected)
+
+    def test_call_tir(self):
+        """Test call_tir functionality with DLPack conversion."""
+        converted_ir_mod = self.converter.convert(["with_call_tir"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["with_call_tir"](x, y)
+        expected = torch.add(x, y)
+        assert torch.allclose(result, expected)
+        assert result.shape == expected.shape
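+        # Conceptually this exercises a DLPack round trip: the torch tensors are
+        # handed to TVM via the DLPack protocol, the compiled add_tir kernel is
+        # run destination-passing-style into a freshly allocated output buffer,
+        # and the result comes back as a torch.Tensor. The exact conversion
+        # helpers are converter internals; only values and shape are asserted.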
+
+    def test_call_dps_packed(self):
+        """Test call_dps_packed functionality."""
+        converted_ir_mod = self.converter.convert(["with_call_dps_packed"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["with_call_dps_packed"](x)
+        expected = x
+        assert torch.allclose(result, expected)
+
+    def test_complex_function(self):
+        """Test complex function with multiple operations."""
+        converted_ir_mod = self.converter.convert(["complex_function"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["complex_function"](x, y)
+
+        # Expected: relu(add(relu(add(x, y)), y))
+        step1 = torch.add(x, y)
+        step2 = torch.nn.functional.relu(step1)
+        step3 = torch.add(step2, y)  # TIR call
+        expected = torch.nn.functional.relu(step3)
+
+        assert torch.allclose(result, expected)
+
+    def test_symbolic_shapes(self):
+        """Test symbolic shape handling."""
+        converted_ir_mod = self.converter.convert(
+            ["symbolic_add", "symbolic_matmul", "symbolic_expand_dims"]
+        )
+
+        # Test symbolic_add
+        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
+        result = converted_ir_mod.pyfuncs["symbolic_add"](x, y)
+        expected = torch.add(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test symbolic_matmul
+        x = torch.randn(2, 3, 4, dtype=torch.float32)  # (batch=2, m=3, k=4)
+        y = torch.randn(2, 4, 5, dtype=torch.float32)  # (batch=2, k=4, n=5)
+        result = converted_ir_mod.pyfuncs["symbolic_matmul"](x, y)
+        expected = torch.matmul(x, y)
+        assert torch.allclose(result, expected)
+        assert result.shape == (2, 3, 5)
+
+        # Test symbolic_expand_dims
+        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
+        result = converted_ir_mod.pyfuncs["symbolic_expand_dims"](x)
+        expected = torch.unsqueeze(x, dim=2)
+        assert torch.allclose(result, expected)
+        assert result.shape == (2, 2, 1)
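+        # None of the symbolic dimensions ("n", "batch", "m", "k", "seq_len")
+        # need to be specialized by the caller: the torch inputs carry concrete
+        # shapes at runtime, so the assertions above only check that each
+        # converted function accepts the sizes it is given and produces outputs
+        # of the matching shape.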
+
+    def test_multi_operations(self):
+        """Test multiple operations in sequence."""
+        converted_ir_mod = self.converter.convert(["multi_ops"])
+
+        x = torch.tensor(
+            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
+            dtype=torch.float32,
+        )
+        y = torch.tensor(
+            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32
+        )
+
+        result = converted_ir_mod.pyfuncs["multi_ops"](x, y)
+
+        # Expected: maximum(power(multiply(add(x, y), y), 2), x)
+        step1 = torch.add(x, y)
+        step2 = torch.mul(step1, y)
+        step3 = torch.pow(step2, 2.0)
+        expected = torch.maximum(step3, x)
+
+        assert torch.allclose(result, expected)
+
+    def test_reduction_operations(self):
+        """Test reduction operations."""
+        converted_ir_mod = self.converter.convert(["reduction_ops"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["reduction_ops"](x)
+
+        # Expected: sum(x) + mean(x) + max(x)
+        expected = torch.sum(x) + torch.mean(x) + torch.max(x)
+
+        assert torch.allclose(result, expected)
+        assert result.shape == ()
+
+    def test_comparison_operations(self):
+        """Test comparison operations."""
+        converted_ir_mod = self.converter.convert(["comparison_ops"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32)
+
+        result = converted_ir_mod.pyfuncs["comparison_ops"](x, y)
+
+        # Expected: logical_and(equal(x, y), greater(x, y))
+        eq_val = torch.eq(x, y)
+        gt_val = torch.gt(x, y)
+        expected = torch.logical_and(eq_val, gt_val)
+
+        assert torch.allclose(result, expected)
+        assert result.dtype == torch.bool
+
+    def test_operator_mapping_completeness(self):
+        """Test that operator mapping is comprehensive."""
+        operator_map = RelaxToPyFuncConverter._get_op_map()
+
+        # Check that we have a good number of operators
+        assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}"
+
+        # Check key operator categories
+        binary_ops = [
+            op
+            for op in operator_map.keys()
+            if op.startswith("relax.") and not op.startswith("relax.nn.")
+        ]
+        nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")]
+
+        assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}"
+        assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}"
+
+        # Check specific important operators
+        important_ops = [
+            "relax.add",
+            "relax.multiply",
+            "relax.nn.relu",
+            "relax.nn.softmax",
+            "relax.matmul",
+            "relax.reshape",
+            "relax.sum",
+            "relax.mean",
+        ]
+
+        for op in important_ops:
+            assert op in operator_map, f"Missing important operator: {op}"
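+        # _get_op_map() is presumably a flat dict keyed by Relax op names and
+        # valued by the PyTorch callables used during conversion, roughly
+        # (illustrative only, not the literal table):
+        #   {"relax.add": torch.add, "relax.nn.relu": F.relu, ...}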
+
+    def test_error_handling(self):
+        """Test error handling for invalid inputs."""
+        converted_ir_mod = self.converter.convert(["simple_add"])
+
+        # Test with wrong number of arguments
+        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+
+        with pytest.raises(ValueError, match="Expected 2 arguments"):
+            converted_ir_mod.pyfuncs["simple_add"](x)  # Missing second 
argument
+
+        # Test with incompatible shapes - this should raise a RuntimeError
+        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        y = torch.tensor([1.0, 2.0], dtype=torch.float32)  # Different shape
+
+        # This should raise a RuntimeError because shapes don't match
+        with pytest.raises(RuntimeError, match="The size of tensor a"):
+            converted_ir_mod.pyfuncs["simple_add"](x, y)
+
+    def test_conversion_metadata(self):
+        """Test that conversion preserves metadata correctly."""
+        converted_ir_mod = self.converter.convert(["simple_add"])
+
+        # Check that pyfuncs attribute exists
+        assert hasattr(converted_ir_mod, "pyfuncs")
+        assert "simple_add" in converted_ir_mod.pyfuncs
+
+        # Check function metadata
+        pyfunc = converted_ir_mod.pyfuncs["simple_add"]
+        assert hasattr(pyfunc, "__name__")
+        assert hasattr(pyfunc, "__doc__")
+        assert pyfunc.__name__ == "simple_add"
+
+    def test_tensor_operations(self):
+        """Test tensor manipulation operations."""
+        converted_ir_mod = self.converter.convert(
+            [
+                "test_reshape",
+                "test_permute_dims",
+                "test_concat",
+                "test_split",
+                "test_stack",
+                "test_take",
+                "test_flip",
+                "test_tile",
+                "test_repeat",
+                "test_expand_dims",
+                "test_squeeze",
+                "test_sum_with_axis",
+                "test_max_with_axis",
+            ]
+        )
+
+        # Test reshape
+        x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result1 = converted_ir_mod.pyfuncs["test_reshape"](x1)
+        expected1 = torch.reshape(x1, (6,))
+        assert torch.allclose(result1, expected1), "Reshape operation failed"
+
+        # Test permute_dims
+        x2 = torch.randn(2, 3, 4)
+        result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2)
+        expected2 = torch.permute(x2, (2, 0, 1))
+        assert torch.allclose(result2, expected2), "Permute_dims operation failed"
+
+        # Test concat
+        x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32)
+        result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3)
+        expected3 = torch.cat([x3, y3], dim=0)
+        assert torch.allclose(result3, expected3), "Concat operation failed"
+
+        # Test split
+        x4 = torch.tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
+            dtype=torch.float32,
+        )
+        result4 = converted_ir_mod.pyfuncs["test_split"](x4)
+        expected4 = torch.split(x4, 2, dim=0)
+        assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors"
+        for r, e in zip(result4, expected4):
+            assert torch.allclose(r, e), "Split operation failed - tensor mismatch"
+
+        # Test stack
+        x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32)
+        result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5)
+        expected5 = torch.stack([x5, y5], dim=1)
+        assert torch.allclose(result5, expected5), "Stack operation failed"
+
+        # Test take
+        x6 = torch.tensor(
+            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
+            dtype=torch.float32,
+        )
+        indices = torch.tensor([0, 2], dtype=torch.int64)
+        result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices)
+        expected6 = x6[indices]
+        assert torch.allclose(result6, expected6), "Take operation failed"
+
+        # Test flip
+        x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result7 = converted_ir_mod.pyfuncs["test_flip"](x7)
+        expected7 = torch.flip(x7, dims=[1])
+        assert torch.allclose(result7, expected7), "Flip operation failed"
+
+        # Test tile
+        x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result8 = converted_ir_mod.pyfuncs["test_tile"](x8)
+        expected8 = torch.tile(x8, (2, 2))
+        assert torch.allclose(result8, expected8), "Tile operation failed"
+
+        # Test repeat
+        x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result9 = converted_ir_mod.pyfuncs["test_repeat"](x9)
+        expected9 = torch.repeat_interleave(x9, repeats=2, dim=0)
+        assert torch.allclose(result9, expected9), "Repeat operation failed"
+
+        # Test expand_dims
+        x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10)
+        expected10 = torch.unsqueeze(x10, dim=2)
+        assert torch.allclose(result10, expected10), "Expand_dims operation failed"
+
+        # Test squeeze
+        x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32)
+        result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11)
+        expected11 = torch.squeeze(x11, dim=2)
+        assert torch.allclose(result11, expected11), "Squeeze operation failed"
+
+        # Test sum with axis
+        x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12)
+        expected12 = torch.sum(x12, dim=0)
+        assert torch.allclose(result12, expected12), "Sum with axis operation failed"
+
+        # Test max with axis
+        x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
+        result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13)
+        expected13 = torch.max(x13, dim=0)[0]  # torch.max returns (values, indices)
+        assert torch.allclose(result13, expected13), "Max with axis operation failed"
+
+
+@I.ir_module
+class ExtendedOperatorsModule:
+    """Extended test module with additional operators not covered in 
ComprehensiveTestModule."""
+
+    # Unary operations not covered in ComprehensiveTestModule
+    @R.function
+    def test_abs(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.abs(x)
+
+    @R.function
+    def test_neg(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.negative(x)
+
+    @R.function
+    def test_exp(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.exp(x)
+
+    @R.function
+    def test_log(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.log(x)
+
+    @R.function
+    def test_sqrt(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.sqrt(x)
+
+    @R.function
+    def test_sin(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.sin(x)
+
+    @R.function
+    def test_cos(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.cos(x)
+
+    @R.function
+    def test_tanh(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.tanh(x)
+
+    @R.function
+    def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.sigmoid(x)
+
+    # Comparison operations not covered in ComprehensiveTestModule
+    @R.function
+    def test_less(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "bool"):
+        return R.less(x, y)
+
+    @R.function
+    def test_not_equal(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "bool"):
+        return R.not_equal(x, y)
+
+    # Binary operations not covered in ComprehensiveTestModule
+    @R.function
+    def test_multiply(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.multiply(x, y)
+
+    @R.function
+    def test_divide(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.divide(x, y)
+
+    @R.function
+    def test_power(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.power(x, y)
+
+    @R.function
+    def test_maximum(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.maximum(x, y)
+
+    @R.function
+    def test_minimum(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.minimum(x, y)
+
+    @R.function
+    def test_subtract(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.subtract(x, y)
+
+    # Additional tensor operations with different parameters
+    @R.function
+    def test_transpose_2d(x: R.Tensor((2, 4), "float32")) -> R.Tensor((4, 2), "float32"):
+        return R.permute_dims(x, axes=[1, 0])
+
+    @R.function
+    def test_mean_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
+        return R.mean(x, axis=0)
+
+    @R.function
+    def test_min_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
+        return R.min(x, axis=0)
+
+    # Neural network operations not covered in ComprehensiveTestModule
+    @R.function
+    def test_gelu_nn(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+        return R.nn.gelu(x)
+
+    @R.function
+    def test_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"):
+        return R.nn.softmax(x, axis=1)
+
+    @R.function
+    def test_log_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"):
+        return R.nn.log_softmax(x, axis=1)
+
+    # Advanced tensor operations with different parameters
+    @R.function
+    def test_tile_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 9), "float32"):
+        return R.tile(x, (2, 3))
+
+    @R.function
+    def test_repeat_axis(x: R.Tensor((3,), "float32")) -> R.Tensor((6,), "float32"):
+        return R.repeat(x, repeats=2, axis=0)
+
+
+class TestExtendedOperators:
+    """Test class for extended operator coverage."""
+
+    @classmethod
+    def setup_class(cls):
+        """Set up test fixtures."""
+        cls.ir_mod = ExtendedOperatorsModule
+        cls.converter = RelaxToPyFuncConverter(cls.ir_mod)
+
+    def test_unary_operations(self):
+        """Test unary operations."""
+        converted_ir_mod = self.converter.convert(
+            ["test_abs", "test_neg", "test_exp", "test_log", "test_sqrt"]
+        )
+
+        x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32)
+
+        # Test abs
+        result = converted_ir_mod.pyfuncs["test_abs"](x)
+        expected = torch.abs(x)
+        assert torch.allclose(result, expected)
+
+        # Test negative
+        result = converted_ir_mod.pyfuncs["test_neg"](x)
+        expected = torch.neg(x)
+        assert torch.allclose(result, expected)
+
+        # Test exp
+        result = converted_ir_mod.pyfuncs["test_exp"](x)
+        expected = torch.exp(x)
+        assert torch.allclose(result, expected)
+
+        # Test log (with positive values)
+        x_pos = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        result = converted_ir_mod.pyfuncs["test_log"](x_pos)
+        expected = torch.log(x_pos)
+        assert torch.allclose(result, expected)
+
+        # Test sqrt
+        result = converted_ir_mod.pyfuncs["test_sqrt"](x_pos)
+        expected = torch.sqrt(x_pos)
+        assert torch.allclose(result, expected)
+
+    def test_trigonometric_operations(self):
+        """Test trigonometric operations."""
+        converted_ir_mod = self.converter.convert(
+            ["test_sin", "test_cos", "test_tanh", "test_sigmoid"]
+        )
+
+        x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32)
+
+        # Test sin
+        result = converted_ir_mod.pyfuncs["test_sin"](x)
+        expected = torch.sin(x)
+        assert torch.allclose(result, expected)
+
+        # Test cos
+        result = converted_ir_mod.pyfuncs["test_cos"](x)
+        expected = torch.cos(x)
+        assert torch.allclose(result, expected)
+
+        # Test tanh
+        result = converted_ir_mod.pyfuncs["test_tanh"](x)
+        expected = torch.tanh(x)
+        assert torch.allclose(result, expected)
+
+        # Test sigmoid
+        result = converted_ir_mod.pyfuncs["test_sigmoid"](x)
+        expected = torch.sigmoid(x)
+        assert torch.allclose(result, expected)
+
+    def test_comparison_operations(self):
+        """Test comparison operations not covered in 
ComprehensiveTestModule."""
+        converted_ir_mod = self.converter.convert(["test_less", 
"test_not_equal"])
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32)
+
+        # Test less
+        result = converted_ir_mod.pyfuncs["test_less"](x, y)
+        expected = torch.lt(x, y)
+        assert torch.equal(result, expected)
+
+        # Test not equal
+        result = converted_ir_mod.pyfuncs["test_not_equal"](x, y)
+        expected = torch.ne(x, y)
+        assert torch.equal(result, expected)
+
+    def test_binary_operations(self):
+        """Test binary operations."""
+        converted_ir_mod = self.converter.convert(
+            [
+                "test_multiply",
+                "test_divide",
+                "test_power",
+                "test_maximum",
+                "test_minimum",
+                "test_subtract",
+            ]
+        )
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
+        y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32)
+
+        # Test multiply
+        result = converted_ir_mod.pyfuncs["test_multiply"](x, y)
+        expected = torch.mul(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test divide
+        result = converted_ir_mod.pyfuncs["test_divide"](x, y)
+        expected = torch.div(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test power
+        result = converted_ir_mod.pyfuncs["test_power"](x, y)
+        expected = torch.pow(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test maximum
+        result = converted_ir_mod.pyfuncs["test_maximum"](x, y)
+        expected = torch.maximum(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test minimum
+        result = converted_ir_mod.pyfuncs["test_minimum"](x, y)
+        expected = torch.minimum(x, y)
+        assert torch.allclose(result, expected)
+
+        # Test subtract
+        result = converted_ir_mod.pyfuncs["test_subtract"](x, y)
+        expected = torch.sub(x, y)
+        assert torch.allclose(result, expected)
+
+    def test_tensor_operations(self):
+        """Test tensor operations not covered in ComprehensiveTestModule."""
+        converted_ir_mod = self.converter.convert(["test_transpose_2d"])
+
+        x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)
+
+        # Test transpose
+        result = converted_ir_mod.pyfuncs["test_transpose_2d"](x)
+        expected = torch.transpose(x, 0, 1)
+        assert torch.allclose(result, expected)
+        assert result.shape == (4, 2)
+
+    def test_reduction_operations(self):
+        """Test reduction operations not covered in ComprehensiveTestModule."""
+        converted_ir_mod = self.converter.convert(["test_mean_axis", 
"test_min_axis"])
+
+        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], 
dtype=torch.float32)
+
+        # Test mean
+        result = converted_ir_mod.pyfuncs["test_mean_axis"](x)
+        expected = torch.mean(x, dim=0)
+        assert torch.allclose(result, expected)
+        assert result.shape == (3,)
+
+        # Test min
+        result = converted_ir_mod.pyfuncs["test_min_axis"](x)
+        expected = torch.min(x, dim=0)[0]
+        assert torch.allclose(result, expected)
+        assert result.shape == (3,)
+
+    def test_neural_network_operations(self):
+        """Test neural network operations not covered in 
ComprehensiveTestModule."""
+        converted_ir_mod = self.converter.convert(
+            ["test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn"]
+        )
+
+        x = torch.tensor(
+            [[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32
+        )
+
+        # Test gelu
+        result = converted_ir_mod.pyfuncs["test_gelu_nn"](x[0])
+        expected = F.gelu(x[0])
+        assert torch.allclose(result, expected)
+
+        # Test softmax
+        result = converted_ir_mod.pyfuncs["test_softmax_nn"](x)
+        expected = F.softmax(x, dim=1)
+        assert torch.allclose(result, expected)
+
+        # Test log_softmax
+        result = converted_ir_mod.pyfuncs["test_log_softmax_nn"](x)
+        expected = F.log_softmax(x, dim=1)
+        assert torch.allclose(result, expected)
+
+    def test_advanced_tensor_operations(self):
+        """Test advanced tensor operations with different parameters."""
+        converted_ir_mod = self.converter.convert(["test_tile_dims", 
"test_repeat_axis"])
+
+        x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)
+
+        # Test tile with different dimensions
+        result = converted_ir_mod.pyfuncs["test_tile_dims"](x)
+        expected = torch.tile(x, (2, 3))
+        assert torch.allclose(result, expected)
+        assert result.shape == (4, 12)
+
+        # Test repeat with different parameters
+        x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        result = converted_ir_mod.pyfuncs["test_repeat_axis"](x_1d)
+        expected = torch.repeat_interleave(x_1d, repeats=2, dim=0)
+        assert torch.allclose(result, expected)
+        assert result.shape == (6,)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
