This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 887a9ca817 [Relax] Building TVMScript printer for IRModules with Python functions (#18253)
887a9ca817 is described below

commit 887a9ca8172722bbf0156292c5d25a8822d66bda
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Sep 3 21:06:22 2025 -0400

    [Relax] Building TVMScript printer for IRModules with Python functions (#18253)
    
    This PR implements TVMScript printer to format IRModules
    containing `@I.pyfunc` decorated Python functions.
    
    Example:
    ```
    @I.ir_module
    class MyModule(BasePyModule):
        @I.pyfunc
        def python_func(self, x, y):
            x_tvm = self._convert_pytorch_to_tvm(x)
            y_tvm = self._convert_pytorch_to_tvm(y)
            result = self.call_tir(self.add_tir, [x_tvm, y_tvm],
                                 out_sinfo=R.Tensor((5,), "float32"))
            return self._convert_tvm_to_pytorch(result)
    
        @T.prim_func
        def add_tir(a: T.handle, b: T.handle, c: T.handle):
            A = T.match_buffer(a, (5,), "float32")
            B = T.match_buffer(b, (5,), "float32")
            C = T.match_buffer(c, (5,), "float32")
            for i in range(5):
                C[i] = A[i] + B[i]
    
    # Usage:
    print(MyModule.script())  # Print formatted TVMScript
    MyModule.show()           # Display formatted output
    ```
---
 python/tvm/relax/base_py_module.py                | 129 +++-
 tests/python/relax/test_base_py_module_printer.py | 760 ++++++++++++++++++++++
 2 files changed, 888 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 2ef17504c8..f463a84fc6 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -16,6 +16,8 @@
 # under the License.
 """BasePyModule: Base class for IRModules with Python function support."""
 
+import inspect
+import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -369,7 +371,6 @@ class BasePyModule:
         # Create a wrapper that handles both instance methods and static functions
         # pylint: disable=import-outside-toplevel
         import functools
-        import inspect
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
@@ -383,3 +384,129 @@ class BasePyModule:
 
         # Set the wrapper as an instance attribute
         setattr(self, name, wrapper)
+
+    def script(
+        self,
+        *,
+        name: Optional[str] = None,
+        show_meta: bool = False,
+        ir_prefix: str = "I",
+        tir_prefix: str = "T",
+        relax_prefix: str = "R",
+        module_alias: str = "cls",
+        buffer_dtype: str = "float32",
+        int_dtype: str = "int32",
+        float_dtype: str = "void",
+        verbose_expr: bool = False,
+        indent_spaces: int = 4,
+        print_line_numbers: bool = False,
+        num_context_lines: int = -1,
+        syntax_sugar: bool = True,
+        show_object_address: bool = False,
+        show_all_struct_info: bool = True,
+    ) -> str:
+        """Print TVM IR into TVMScript text format with Python function support.
+
+        This method extends the standard IRModule script() method to handle
+        Python functions stored in the IRModule's pyfuncs attribute.
+        """
+        # First get the standard IRModule script
+        base_script = self.ir_mod.script(
+            name=name,
+            show_meta=show_meta,
+            ir_prefix=ir_prefix,
+            tir_prefix=tir_prefix,
+            relax_prefix=relax_prefix,
+            module_alias=module_alias,
+            buffer_dtype=buffer_dtype,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            verbose_expr=verbose_expr,
+            indent_spaces=indent_spaces,
+            print_line_numbers=print_line_numbers,
+            num_context_lines=num_context_lines,
+            syntax_sugar=syntax_sugar,
+            show_object_address=show_object_address,
+            show_all_struct_info=show_all_struct_info,
+        )
+
+        # If there are no Python functions, return the base script
+        if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
+            return base_script
+
+        # Insert Python functions into the script
+        return self._insert_python_functions(base_script, indent_spaces)
+
+    def _insert_python_functions(self, base_script: str, indent_spaces: int) -> str:
+        """Insert Python functions into the TVMScript output."""
+        lines = base_script.split("\n")
+        result_lines = []
+
+        # Find the class definition line and insert Python functions after it
+        class_found = False
+        class_indent = 0
+
+        for line in lines:
+            result_lines.append(line)
+
+            # Look for class definition
+            if not class_found and line.strip().startswith("class "):
+                class_found = True
+                class_indent = len(line) - len(line.lstrip())
+
+                # Insert Python functions after the class definition
+                if hasattr(self.ir_mod, "pyfuncs") and self.ir_mod.pyfuncs:
+                    for func_name, func in self.ir_mod.pyfuncs.items():
+                        # Get the function source code
+                        func_source = self._get_function_source(func)
+                        if func_source:
+                            # Format the function with proper indentation
+                            formatted_func = self._format_python_function(
+                                func_name, func_source, class_indent + indent_spaces
+                            )
+                            result_lines.append(formatted_func)
+                            result_lines.append("")  # Add empty line for separation
+
+        return "\n".join(result_lines)
+
+    def _get_function_source(self, func: callable) -> Optional[str]:
+        """Get the source code of a Python function."""
+        try:
+            source = inspect.getsource(func)
+            return source
+        except (OSError, TypeError):
+            # If we can't get the source, return None
+            return None
+
+    def _format_python_function(self, _func_name: str, func_source: str, indent: int) -> str:
+        """Format a Python function with proper indentation for TVMScript."""
+        lines = func_source.split("\n")
+        formatted_lines = []
+
+        for line in lines:
+            # Skip the function definition line if it's already properly indented
+            if line.strip().startswith("def ") or line.strip().startswith("@"):
+                # Keep decorators and function definition as is
+                formatted_lines.append(" " * indent + line.strip())
+            else:
+                # Add proper indentation for the function body
+                formatted_lines.append(" " * indent + line.strip())
+
+        return "\n".join(formatted_lines)
+
+    def show(
+        self, style: Optional[str] = None, black_format: Optional[bool] = None, **kwargs
+    ) -> None:
+        """A sugar for print highlighted TVM script with Python function support.
+
+        This method extends the standard IRModule show() method to handle
+        Python functions stored in the IRModule's pyfuncs attribute.
+        """
+        from tvm.script.highlight import cprint  # pylint: disable=import-outside-toplevel
+
+        if black_format is None:
+            env = os.environ.get("TVM_BLACK_FORMAT")
+            black_format = env and int(env)
+
+        script_content = self.script(**kwargs)
+        cprint(script_content, style=style, black_format=black_format)
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
new file mode 100644
index 0000000000..92c799f6cb
--- /dev/null
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -0,0 +1,760 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unused-argument
+
+import pytest
+import tvm
+from tvm.relax.base_py_module import BasePyModule
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+
+
+@I.ir_module
+class SimplePyFuncModule(BasePyModule):
+    """Test simple Python functions with basic operations."""
+
+    @I.pyfunc
+    def add(self, x, y):
+        """Simple addition function."""
+        x_tvm = self._convert_pytorch_to_tvm(x)
+        y_tvm = self._convert_pytorch_to_tvm(y)
+        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32"))
+        return self._convert_tvm_to_pytorch(result)
+
+    @I.pyfunc
+    def multiply(self, x, y):
+        """Simple multiplication function."""
+        x_tvm = self._convert_pytorch_to_tvm(x)
+        y_tvm = self._convert_pytorch_to_tvm(y)
+        result = self.call_tir(
+            self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")
+        )
+        return self._convert_tvm_to_pytorch(result)
+
+    @T.prim_func
+    def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
+        x = T.match_buffer(var_x, (5,), "float32")
+        y = T.match_buffer(var_y, (5,), "float32")
+        out = T.match_buffer(var_out, (5,), "float32")
+
+        for i in range(5):
+            out[i] = x[i] + y[i]
+
+    @T.prim_func
+    def multiply_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
+        x = T.match_buffer(var_x, (5,), "float32")
+        y = T.match_buffer(var_y, (5,), "float32")
+        out = T.match_buffer(var_out, (5,), "float32")
+
+        for i in range(5):
+            out[i] = x[i] * y[i]
+
+    @R.function
+    def main_relax(
+        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+    ) -> R.Tensor((5,), "float32"):
+        return R.add(x, y)
+
+
+@I.ir_module
+class ComplexPyFuncModule(BasePyModule):
+    """Test complex Python logic with ML pipeline and error handling."""
+
+    @I.pyfunc
+    def ml_pipeline(self, input_data, model_params):
+        """Complex ML pipeline with data validation and error handling."""
+        # Data validation
+        if input_data is None or model_params is None:
+            raise ValueError("Inputs cannot be None")
+
+        try:
+            # Convert to TVM format
+            tvm_data = self._convert_pytorch_to_tvm(input_data)
+            tvm_params = self._convert_pytorch_to_tvm(model_params)
+
+            # Run ML inference
+            features = self.call_tir(
+                self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32")
+            )
+
+            predictions = self.call_tir(
+                self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32")
+            )
+
+            # Post-process results
+            final_result = self.call_tir(
+                self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32")
+            )
+
+            return self._convert_tvm_to_pytorch(final_result)
+
+        except Exception as e:
+            self._log_error(f"ML pipeline failed: {e}")
+            return self._get_default_value()
+
+    @I.pyfunc
+    def data_preprocessing(self, raw_data):
+        """Data preprocessing with conditional logic."""
+        if hasattr(raw_data, "numpy"):
+            # Vectorized path for numpy-compatible data
+            data_np = raw_data.numpy()
+            processed = self._vectorized_preprocess(data_np)
+        else:
+            # Fallback path for other data types
+            processed = self._elementwise_preprocess(raw_data)
+
+        # Convert and return
+        tvm_processed = self._convert_pytorch_to_tvm(processed)
+        result = self.call_tir(
+            self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32")
+        )
+        return self._convert_tvm_to_pytorch(result)
+
+    @T.prim_func
+    def extract_features(data: T.handle, features: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Data = T.match_buffer(data, (10,), "float32")
+        Features = T.match_buffer(features, (10,), "float32")
+
+        for i in range(10):
+            Features[i] = T.sqrt(Data[i])
+
+    @T.prim_func
+    def ml_inference(features: T.handle, params: T.handle, output: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Features = T.match_buffer(features, (10,), "float32")
+        Params = T.match_buffer(params, (10,), "float32")
+        Output = T.match_buffer(output, (5,), "float32")
+
+        for i in range(5):
+            Output[i] = Features[i] * Params[i] + Features[i + 5] * Params[i + 5]
+
+    @T.prim_func
+    def post_process(predictions: T.handle, final: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Predictions = T.match_buffer(predictions, (5,), "float32")
+        Final = T.match_buffer(final, (5,), "float32")
+
+        for i in range(5):
+            Final[i] = T.max(Predictions[i], 0.0)
+
+    @T.prim_func
+    def normalize_data(data: T.handle, normalized: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Data = T.match_buffer(data, (10,), "float32")
+        Normalized = T.match_buffer(normalized, (10,), "float32")
+
+        for i in range(10):
+            Normalized[i] = Data[i] / 255.0
+
+
+@I.ir_module
+class EdgeCasePyFuncModule(BasePyModule):
+    """Test edge cases and boundary conditions."""
+
+    @I.pyfunc
+    def empty_func(self):
+        """Empty function with no operations."""
+        pass
+
+    @I.pyfunc
+    def single_return(self, x):
+        """Function with immediate return."""
+        return x
+
+    @I.pyfunc
+    def nested_conditionals(self, data, threshold):
+        """Function with complex nested conditional logic."""
+        if data is None:
+            return None
+
+        if hasattr(data, "shape"):
+            if len(data.shape) == 1:
+                if data.shape[0] > threshold:
+                    return self._process_large_data(data)
+                else:
+                    return self._process_small_data(data)
+            elif len(data.shape) == 2:
+                return self._process_2d_data(data)
+            else:
+                return self._process_nd_data(data)
+        else:
+            return self._process_scalar_data(data)
+
+    @I.pyfunc
+    def loop_with_break(self, data, max_iter):
+        """Function with loop and break statement."""
+        result = []
+        for i, item in enumerate(data):
+            if i >= max_iter:
+                break
+            if item > 0:
+                result.append(item * 2)
+            else:
+                result.append(0)
+        return result
+
+    @T.prim_func
+    def dummy_tir(data: T.handle, output: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Data = T.match_buffer(data, (1,), "float32")
+        Output = T.match_buffer(output, (1,), "float32")
+        Output[0] = Data[0]
+
+
+@I.ir_module
+class PerformancePyFuncModule(BasePyModule):
+    """Test performance optimization patterns."""
+
+    @I.pyfunc
+    def vectorized_operation(self, x, y):
+        """Vectorized operation with numpy fallback."""
+        try:
+            # Try vectorized operation first
+            if hasattr(x, "numpy") and hasattr(y, "numpy"):
+                x_np = x.numpy()
+                y_np = y.numpy()
+                result_np = x_np + y_np
+                return self._convert_numpy_to_pytorch(result_np)
+        except Exception:
+            pass
+
+        # Fallback to TVM processing
+        x_tvm = self._convert_pytorch_to_tvm(x)
+        y_tvm = self._convert_pytorch_to_tvm(y)
+        result = self.call_tir(
+            self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32")
+        )
+        return self._convert_tvm_to_pytorch(result)
+
+    @I.pyfunc
+    def batch_processing(self, batch_data):
+        """Batch processing with memory optimization."""
+        batch_size = len(batch_data)
+        results = []
+
+        # Process in chunks to optimize memory usage
+        chunk_size = min(batch_size, 100)
+        for i in range(0, batch_size, chunk_size):
+            chunk = batch_data[i : i + chunk_size]
+            chunk_result = self._process_chunk(chunk)
+            results.extend(chunk_result)
+
+        return results
+
+    @I.pyfunc
+    def memory_efficient_transform(self, large_tensor):
+        """Memory-efficient tensor transformation."""
+        # Use in-place operations when possible
+        if hasattr(large_tensor, "requires_grad") and not large_tensor.requires_grad:
+            # In-place operation for efficiency
+            large_tensor.add_(1.0)
+            return large_tensor
+        else:
+            # Create new tensor if gradients are needed
+            return large_tensor + 1.0
+
+    @T.prim_func
+    def vectorized_add(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"tir.noalias": True})
+        A = T.match_buffer(a, (10,), "float32")
+        B = T.match_buffer(b, (10,), "float32")
+        C = T.match_buffer(c, (10,), "float32")
+
+        for i in range(10):
+            C[i] = A[i] + B[i]
+
+
+@I.ir_module
+class IntegrationPyFuncModule(BasePyModule):
+    """Test integration with external libraries and complex workflows."""
+
+    @I.pyfunc
+    def sklearn_integration(self, input_data, scaler_params):
+        """Integration with scikit-learn preprocessing."""
+        try:
+            # Import sklearn components
+            from sklearn.preprocessing import StandardScaler
+            from sklearn.decomposition import PCA
+
+            # Create and fit scaler
+            scaler = StandardScaler()
+            if scaler_params is not None:
+                scaler.mean_ = scaler_params["mean"]
+                scaler.scale_ = scaler_params["scale"]
+            else:
+                scaler.fit(input_data)
+
+            # Transform data
+            scaled_data = scaler.transform(input_data)
+
+            # Apply PCA if needed
+            if input_data.shape[1] > 10:
+                pca = PCA(n_components=10)
+                reduced_data = pca.fit_transform(scaled_data)
+            else:
+                reduced_data = scaled_data
+
+            # Convert to TVM and process
+            tvm_data = self._convert_pytorch_to_tvm(reduced_data)
+            result = self.call_tir(
+                self.final_transform,
+                [tvm_data],
+                out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"),
+            )
+
+            return self._convert_tvm_to_pytorch(result)
+
+        except ImportError:
+            # Fallback if sklearn is not available
+            return self._fallback_preprocessing(input_data)
+
+    @I.pyfunc
+    def multi_stage_pipeline(self, raw_input):
+        """Multi-stage processing pipeline."""
+        # Stage 1: Data cleaning
+        cleaned = self._clean_data(raw_input)
+
+        # Stage 2: Feature extraction
+        features = self._extract_features(cleaned)
+
+        # Stage 3: Model inference
+        predictions = self._run_inference(features)
+
+        # Stage 4: Post-processing
+        final_result = self._post_process_output(predictions)
+
+        return final_result
+
+    @T.prim_func
+    def final_transform(data: T.handle, output: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Data = T.match_buffer(data, (10, 10), "float32")
+        Output = T.match_buffer(output, (10, 10), "float32")
+
+        for i in range(10):
+            for j in range(10):
+                Output[i, j] = T.tanh(Data[i, j])
+
+
+@I.ir_module
+class ErrorHandlingPyFuncModule(BasePyModule):
+    """Test comprehensive error handling and validation."""
+
+    @I.pyfunc
+    def robust_data_processing(self, input_data, config):
+        """Robust data processing with comprehensive error handling."""
+        try:
+            # Validate inputs
+            if not self._validate_inputs(input_data, config):
+                raise ValueError("Invalid input data or configuration")
+
+            # Check data types
+            if not self._check_data_types(input_data):
+                raise TypeError("Unsupported data types")
+
+            # Process data with retry logic
+            max_retries = config.get("max_retries", 3)
+            for attempt in range(max_retries):
+                try:
+                    result = self._process_with_validation(input_data, config)
+                    if self._validate_output(result):
+                        return result
+                    else:
+                        raise RuntimeError("Output validation failed")
+                except Exception as e:
+                    if attempt == max_retries - 1:
+                        raise
+                    self._log_warning(f"Attempt {attempt + 1} failed: {e}")
+                    continue
+
+        except Exception as e:
+            self._log_error(f"Data processing failed: {e}")
+            return self._get_safe_fallback(input_data, config)
+
+    @I.pyfunc
+    def graceful_degradation(self, primary_input, fallback_input):
+        """Function that gracefully degrades when primary path fails."""
+        try:
+            # Try primary processing path
+            result = self._primary_processing(primary_input)
+            return result
+        except Exception as e:
+            self._log_warning(f"Primary processing failed: {e}")
+
+            try:
+                # Try fallback path
+                result = self._fallback_processing(fallback_input)
+                return result
+            except Exception as e2:
+                self._log_error(f"Fallback processing also failed: {e2}")
+                # Return safe default
+                return self._get_safe_default()
+
+    @T.prim_func
+    def safe_transform(data: T.handle, output: T.handle):
+        T.func_attr({"tir.noalias": True})
+        Data = T.match_buffer(data, (5,), "float32")
+        Output = T.match_buffer(output, (5,), "float32")
+
+        for i in range(5):
+            # Safe operation that handles edge cases
+            if Data[i] > 0:
+                Output[i] = T.sqrt(Data[i])
+            else:
+                Output[i] = 0.0
+
+
+if __name__ == "__main__":
+    # This allows the file to be run directly for debugging
+    # In normal pytest usage, these classes are automatically tested by TVMScript
+    print("All test modules defined successfully!")
+    print("TVMScript will automatically validate these modules during testing.")
+
+    # Demo the printer functionality
+    print("\n" + "=" * 60)
+    print("DEMO: BasePyModule Printer Functionality")
+    print("=" * 60)
+
+    # Test the printer with SimplePyFuncModule
+    try:
+        ir_mod = SimplePyFuncModule
+        device = tvm.cpu()
+        module = BasePyModule(ir_mod, device)
+
+        print("\n1. Testing script() method:")
+        print("-" * 40)
+        script_output = module.script()
+        print(script_output[:500] + "..." if len(script_output) > 500 else script_output)
+
+        print("\n2. Testing show() method:")
+        print("-" * 40)
+        module.show()
+
+        print("\n3. Python functions found in pyfuncs:")
+        print("-" * 40)
+        if hasattr(ir_mod, "pyfuncs"):
+            for name, func in ir_mod.pyfuncs.items():
+                print(f"  - {name}: {func}")
+        else:
+            print("  No pyfuncs attribute found")
+
+    except Exception as e:
+        print(f"Demo failed: {e}")
+        print("This is expected for testing-only TVMScript code.")
+
+    # Run all tests using tvm.testing.main()
+    print("\n" + "=" * 60)
+    print("Running all tests with tvm.testing.main()...")
+    print("=" * 60)
+
+    import tvm.testing
+
+    tvm.testing.main()
+
+
+# Pytest test functions to verify the classes work correctly
+def test_simple_pyfunc_module_creation():
+    """Test that SimplePyFuncModule can be created."""
+    # Get the IRModule instance from the TVMScript decorated class
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+
+    # Create BasePyModule instance
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Note: Python functions are stored in pyfuncs, not as direct attributes
+    # We need to check if they exist in the IRModule's pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "add" in ir_mod.pyfuncs
+        assert "multiply" in ir_mod.pyfuncs
+
+    # Check that TIR functions exist
+    assert hasattr(module, "add_tir")
+    assert hasattr(module, "multiply_tir")
+
+    # Note: This particular TVMScript is for testing purpose only, and cannot compile
+    # Relax functions may not be available due to TVMScript compilation issues
+    print("Note: This TVMScript is for testing purpose only, and cannot compile")
+
+
+def test_complex_pyfunc_module_creation():
+    """Test that ComplexPyFuncModule can be created."""
+    ir_mod = ComplexPyFuncModule
+    device = tvm.cpu()
+
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Check Python functions in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "ml_pipeline" in ir_mod.pyfuncs
+        assert "data_preprocessing" in ir_mod.pyfuncs
+
+    # Check TIR functions
+    assert hasattr(module, "extract_features")
+    assert hasattr(module, "ml_inference")
+    assert hasattr(module, "post_process")
+    assert hasattr(module, "normalize_data")
+
+
+def test_edge_case_pyfunc_module_creation():
+    """Test that EdgeCasePyFuncModule can be created."""
+    ir_mod = EdgeCasePyFuncModule
+    device = tvm.cpu()
+
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Check Python functions in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "empty_func" in ir_mod.pyfuncs
+        assert "single_return" in ir_mod.pyfuncs
+        assert "nested_conditionals" in ir_mod.pyfuncs
+        assert "loop_with_break" in ir_mod.pyfuncs
+
+    # Check TIR function
+    assert hasattr(module, "dummy_tir")
+
+
+def test_performance_pyfunc_module_creation():
+    """Test that PerformancePyFuncModule can be created."""
+    ir_mod = PerformancePyFuncModule
+    device = tvm.cpu()
+
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Check Python functions in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "vectorized_operation" in ir_mod.pyfuncs
+        assert "batch_processing" in ir_mod.pyfuncs
+        assert "memory_efficient_transform" in ir_mod.pyfuncs
+
+    # Check TIR function
+    assert hasattr(module, "vectorized_add")
+
+
+def test_integration_pyfunc_module_creation():
+    """Test that IntegrationPyFuncModule can be created."""
+    ir_mod = IntegrationPyFuncModule
+    device = tvm.cpu()
+
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Check Python functions in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "sklearn_integration" in ir_mod.pyfuncs
+        assert "multi_stage_pipeline" in ir_mod.pyfuncs
+
+    # Check TIR function
+    assert hasattr(module, "final_transform")
+
+
+def test_error_handling_pyfunc_module_creation():
+    """Test that ErrorHandlingPyFuncModule can be created."""
+    ir_mod = ErrorHandlingPyFuncModule
+    device = tvm.cpu()
+
+    module = BasePyModule(ir_mod, device)
+    assert isinstance(module, BasePyModule)
+
+    # Check Python functions in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "robust_data_processing" in ir_mod.pyfuncs
+        assert "graceful_degradation" in ir_mod.pyfuncs
+
+    # Check TIR function
+    assert hasattr(module, "safe_transform")
+
+
+def test_all_modules_inherit_from_base():
+    """Test that all modules properly inherit from BasePyModule."""
+    modules = [
+        SimplePyFuncModule,
+        ComplexPyFuncModule,
+        EdgeCasePyFuncModule,
+        PerformancePyFuncModule,
+        IntegrationPyFuncModule,
+        ErrorHandlingPyFuncModule,
+    ]
+
+    device = tvm.cpu()
+    for ir_mod in modules:
+        module = BasePyModule(ir_mod, device)
+        assert isinstance(module, BasePyModule)
+        assert hasattr(module, "script")
+        assert hasattr(module, "show")
+
+
+def test_pyfunc_decorators():
+    """Test that all @I.pyfunc decorated functions are present."""
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Check that the functions exist in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "add" in ir_mod.pyfuncs
+        assert "multiply" in ir_mod.pyfuncs
+
+        # Get the actual function objects
+        add_func = ir_mod.pyfuncs["add"]
+        multiply_func = ir_mod.pyfuncs["multiply"]
+
+        # Check that they are callable
+        assert callable(add_func)
+        assert callable(multiply_func)
+
+        # Check function signatures
+        import inspect
+
+        add_sig = inspect.signature(add_func)
+        assert len(add_sig.parameters) == 3  # self, x, y
+
+        multiply_sig = inspect.signature(multiply_func)
+        assert len(multiply_sig.parameters) == 3  # self, x, y
+
+
+def test_tir_functions():
+    """Test that TIR functions are properly defined."""
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Check TIR function attributes
+    assert hasattr(module, "add_tir")
+    assert hasattr(module, "multiply_tir")
+
+    # These should be callable (though they're TIR functions)
+    assert callable(module.add_tir)
+    assert callable(module.multiply_tir)
+
+
+def test_relax_functions():
+    """Test that Relax functions are properly defined."""
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Note: This particular TVMScript is for testing purpose only, and cannot compile
+    # Relax functions may not be available due to TVMScript compilation issues
+    print("Note: This TVMScript is for testing purpose only, and cannot compile")
+
+    # We can still check that the module was created successfully
+    assert isinstance(module, BasePyModule)
+    assert hasattr(module, "script")
+    assert hasattr(module, "show")
+
+
+def test_module_docstrings():
+    """Test that all modules have proper docstrings."""
+    modules = [
+        SimplePyFuncModule,
+        ComplexPyFuncModule,
+        EdgeCasePyFuncModule,
+        PerformancePyFuncModule,
+        IntegrationPyFuncModule,
+        ErrorHandlingPyFuncModule,
+    ]
+
+    for module_class in modules:
+        # TVMScript decorator changes the class, so we check that it's callable
+        # and can create instances instead of checking docstrings
+        assert callable(module_class)
+        # We can't directly instantiate TVMScript decorated classes
+        # but we can create BasePyModule instances with them
+        device = tvm.cpu()
+        instance = BasePyModule(module_class, device)
+        assert isinstance(instance, BasePyModule)
+
+
+def test_python_function_complexity():
+    """Test that complex Python functions have the expected structure."""
+    ir_mod = ComplexPyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Check that complex functions exist in pyfuncs
+    if hasattr(ir_mod, "pyfuncs"):
+        assert "ml_pipeline" in ir_mod.pyfuncs
+        assert "data_preprocessing" in ir_mod.pyfuncs
+
+        # Get the actual function objects
+        ml_func = ir_mod.pyfuncs["ml_pipeline"]
+        preprocess_func = ir_mod.pyfuncs["data_preprocessing"]
+
+        # These should be callable
+        assert callable(ml_func)
+        assert callable(preprocess_func)
+
+        # Check function signatures
+        import inspect
+
+        ml_sig = inspect.signature(ml_func)
+        assert len(ml_sig.parameters) == 3  # self, input_data, model_params
+
+        preprocess_sig = inspect.signature(preprocess_func)
+        assert len(preprocess_sig.parameters) == 2  # self, raw_data
+
+
+def test_script_and_show_methods():
+    """Test that script() and show() methods work correctly."""
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Test script() method
+    script_output = module.script()
+    assert isinstance(script_output, str)
+    assert len(script_output) > 0
+
+    # Test show() method
+    try:
+        module.show()
+        # If we get here, show() worked
+        assert True
+    except Exception as e:
+        # If show() fails, the feature is not working properly
+        pytest.fail(f"show() method failed: {e}")
+
+
+def test_python_functions_in_irmodule():
+    """Test that Python functions are properly stored in IRModule pyfuncs."""
+    ir_mod = SimplePyFuncModule
+    device = tvm.cpu()
+    module = BasePyModule(ir_mod, device)
+
+    # Check that pyfuncs attribute exists and contains our functions
+    if hasattr(ir_mod, "pyfuncs"):
+        pyfuncs = ir_mod.pyfuncs
+        assert isinstance(pyfuncs, dict)
+        assert "add" in pyfuncs
+        assert "multiply" in pyfuncs
+
+        # Check that the functions are callable
+        assert callable(pyfuncs["add"])
+        assert callable(pyfuncs["multiply"])
+
+        # Check function names
+        assert pyfuncs["add"].__name__ == "add"
+        assert pyfuncs["multiply"].__name__ == "multiply"
+    else:
+        pytest.fail("pyfuncs attribute not found in IRModule")

Reply via email to