This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 887a9ca817 [Relax] Building TVMScript printer for IRModules with
Python functions (#18253)
887a9ca817 is described below
commit 887a9ca8172722bbf0156292c5d25a8822d66bda
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Sep 3 21:06:22 2025 -0400
[Relax] Building TVMScript printer for IRModules with Python functions
(#18253)
This PR implements TVMScript printer to format IRModules
containing `@I.pyfunc` decorated Python functions.
Example:
```
@I.ir_module
class MyModule(BasePyModule):
@I.pyfunc
def python_func(self, x, y):
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
result = self.call_tir(self.add_tir, [x_tvm, y_tvm],
out_sinfo=R.Tensor((5,), "float32"))
return self._convert_tvm_to_pytorch(result)
@T.prim_func
def add_tir(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, (5,), "float32")
B = T.match_buffer(b, (5,), "float32")
C = T.match_buffer(c, (5,), "float32")
for i in range(5):
C[i] = A[i] + B[i]
# Usage:
print(MyModule.script()) # Print formatted TVMScript
MyModule.show() # Display formatted output
```
---
python/tvm/relax/base_py_module.py | 129 +++-
tests/python/relax/test_base_py_module_printer.py | 760 ++++++++++++++++++++++
2 files changed, 888 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/base_py_module.py
b/python/tvm/relax/base_py_module.py
index 2ef17504c8..f463a84fc6 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -16,6 +16,8 @@
# under the License.
"""BasePyModule: Base class for IRModules with Python function support."""
+import inspect
+import os
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -369,7 +371,6 @@ class BasePyModule:
# Create a wrapper that handles both instance methods and static
functions
# pylint: disable=import-outside-toplevel
import functools
- import inspect
@functools.wraps(func)
def wrapper(*args, **kwargs):
@@ -383,3 +384,129 @@ class BasePyModule:
# Set the wrapper as an instance attribute
setattr(self, name, wrapper)
+
def script(
    self,
    *,
    name: Optional[str] = None,
    show_meta: bool = False,
    ir_prefix: str = "I",
    tir_prefix: str = "T",
    relax_prefix: str = "R",
    module_alias: str = "cls",
    buffer_dtype: str = "float32",
    int_dtype: str = "int32",
    float_dtype: str = "void",
    verbose_expr: bool = False,
    indent_spaces: int = 4,
    print_line_numbers: bool = False,
    num_context_lines: int = -1,
    syntax_sugar: bool = True,
    show_object_address: bool = False,
    show_all_struct_info: bool = True,
) -> str:
    """Render this module as TVMScript text, including Python functions.

    Delegates to the wrapped IRModule's ``script`` method with the same
    printer options, then — when the IRModule carries a non-empty
    ``pyfuncs`` table — splices the source of each Python function into
    the printed class body.
    """
    # Gather all printer options once so the delegation stays readable.
    printer_options = dict(
        name=name,
        show_meta=show_meta,
        ir_prefix=ir_prefix,
        tir_prefix=tir_prefix,
        relax_prefix=relax_prefix,
        module_alias=module_alias,
        buffer_dtype=buffer_dtype,
        int_dtype=int_dtype,
        float_dtype=float_dtype,
        verbose_expr=verbose_expr,
        indent_spaces=indent_spaces,
        print_line_numbers=print_line_numbers,
        num_context_lines=num_context_lines,
        syntax_sugar=syntax_sugar,
        show_object_address=show_object_address,
        show_all_struct_info=show_all_struct_info,
    )
    base_script = self.ir_mod.script(**printer_options)

    # Without Python functions the standard printer output is complete.
    if not getattr(self.ir_mod, "pyfuncs", None):
        return base_script

    return self._insert_python_functions(base_script, indent_spaces)
+
def _insert_python_functions(self, base_script: str, indent_spaces: int) -> str:
    """Splice the source of each pyfunc into the printed TVMScript class body."""
    pyfuncs = getattr(self.ir_mod, "pyfuncs", None) or {}
    output_lines = []
    inserted = False

    for script_line in base_script.split("\n"):
        output_lines.append(script_line)
        if inserted or not script_line.strip().startswith("class "):
            continue

        # First `class ...` line of the printed module: emit every Python
        # function right below it, indented one level into the class body.
        inserted = True
        class_indent = len(script_line) - len(script_line.lstrip())
        for func_name, func in pyfuncs.items():
            func_source = self._get_function_source(func)
            if func_source:
                output_lines.append(
                    self._format_python_function(
                        func_name, func_source, class_indent + indent_spaces
                    )
                )
                output_lines.append("")  # blank separator line

    return "\n".join(output_lines)
+
+ def _get_function_source(self, func: callable) -> Optional[str]:
+ """Get the source code of a Python function."""
+ try:
+ source = inspect.getsource(func)
+ return source
+ except (OSError, TypeError):
+ # If we can't get the source, return None
+ return None
+
+ def _format_python_function(self, _func_name: str, func_source: str,
indent: int) -> str:
+ """Format a Python function with proper indentation for TVMScript."""
+ lines = func_source.split("\n")
+ formatted_lines = []
+
+ for line in lines:
+ # Skip the function definition line if it's already properly
indented
+ if line.strip().startswith("def ") or line.strip().startswith("@"):
+ # Keep decorators and function definition as is
+ formatted_lines.append(" " * indent + line.strip())
+ else:
+ # Add proper indentation for the function body
+ formatted_lines.append(" " * indent + line.strip())
+
+ return "\n".join(formatted_lines)
+
def show(
    self, style: Optional[str] = None, black_format: Optional[bool] = None, **kwargs
) -> None:
    """Print this module as syntax-highlighted TVMScript, pyfuncs included.

    Mirrors ``IRModule.show`` but routes through :meth:`script` so the
    printed text also contains the ``@I.pyfunc`` functions.
    """
    from tvm.script.highlight import cprint  # pylint: disable=import-outside-toplevel

    if black_format is None:
        # Same opt-in convention as IRModule.show: set TVM_BLACK_FORMAT=1.
        env_flag = os.environ.get("TVM_BLACK_FORMAT")
        black_format = env_flag and int(env_flag)

    cprint(self.script(**kwargs), style=style, black_format=black_format)
diff --git a/tests/python/relax/test_base_py_module_printer.py
b/tests/python/relax/test_base_py_module_printer.py
new file mode 100644
index 0000000000..92c799f6cb
--- /dev/null
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -0,0 +1,760 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unused-argument
+
+import pytest
+import tvm
+from tvm.relax.base_py_module import BasePyModule
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+
+
@I.ir_module
class SimplePyFuncModule(BasePyModule):
    """Test simple Python functions with basic operations."""

    @I.pyfunc
    def add(self, x, y):
        """Simple addition function."""
        # Round-trip: convert inputs to TVM, run the TIR kernel, convert back.
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32"))
        return self._convert_tvm_to_pytorch(result)

    @I.pyfunc
    def multiply(self, x, y):
        """Simple multiplication function."""
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        result = self.call_tir(
            self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")
        )
        return self._convert_tvm_to_pytorch(result)

    @T.prim_func
    def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
        # Element-wise addition over fixed-size (5,) float32 buffers.
        x = T.match_buffer(var_x, (5,), "float32")
        y = T.match_buffer(var_y, (5,), "float32")
        out = T.match_buffer(var_out, (5,), "float32")

        for i in range(5):
            out[i] = x[i] + y[i]

    @T.prim_func
    def multiply_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
        # Element-wise multiplication over fixed-size (5,) float32 buffers.
        x = T.match_buffer(var_x, (5,), "float32")
        y = T.match_buffer(var_y, (5,), "float32")
        out = T.match_buffer(var_out, (5,), "float32")

        for i in range(5):
            out[i] = x[i] * y[i]

    @R.function
    def main_relax(
        x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
    ) -> R.Tensor((5,), "float32"):
        return R.add(x, y)
+
+
@I.ir_module
class ComplexPyFuncModule(BasePyModule):
    """Test complex Python logic with ML pipeline and error handling."""

    @I.pyfunc
    def ml_pipeline(self, input_data, model_params):
        """Complex ML pipeline with data validation and error handling."""
        # Data validation
        if input_data is None or model_params is None:
            raise ValueError("Inputs cannot be None")

        try:
            # Convert to TVM format
            tvm_data = self._convert_pytorch_to_tvm(input_data)
            tvm_params = self._convert_pytorch_to_tvm(model_params)

            # Run ML inference
            features = self.call_tir(
                self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32")
            )

            predictions = self.call_tir(
                self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32")
            )

            # Post-process results
            final_result = self.call_tir(
                self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32")
            )

            return self._convert_tvm_to_pytorch(final_result)

        except Exception as e:
            # Best-effort pipeline: log the failure and return a default.
            self._log_error(f"ML pipeline failed: {e}")
            return self._get_default_value()

    @I.pyfunc
    def data_preprocessing(self, raw_data):
        """Data preprocessing with conditional logic."""
        if hasattr(raw_data, "numpy"):
            # Vectorized path for numpy-compatible data
            data_np = raw_data.numpy()
            processed = self._vectorized_preprocess(data_np)
        else:
            # Fallback path for other data types
            processed = self._elementwise_preprocess(raw_data)

        # Convert and return
        tvm_processed = self._convert_pytorch_to_tvm(processed)
        result = self.call_tir(
            self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32")
        )
        return self._convert_tvm_to_pytorch(result)

    @T.prim_func
    def extract_features(data: T.handle, features: T.handle):
        # Element-wise square root as a stand-in feature extractor.
        T.func_attr({"tir.noalias": True})
        Data = T.match_buffer(data, (10,), "float32")
        Features = T.match_buffer(features, (10,), "float32")

        for i in range(10):
            Features[i] = T.sqrt(Data[i])

    @T.prim_func
    def ml_inference(features: T.handle, params: T.handle, output: T.handle):
        # Combines the two halves of the feature/param vectors pairwise.
        T.func_attr({"tir.noalias": True})
        Features = T.match_buffer(features, (10,), "float32")
        Params = T.match_buffer(params, (10,), "float32")
        Output = T.match_buffer(output, (5,), "float32")

        for i in range(5):
            Output[i] = Features[i] * Params[i] + Features[i + 5] * Params[i + 5]

    @T.prim_func
    def post_process(predictions: T.handle, final: T.handle):
        # ReLU-style clamp of predictions to non-negative values.
        T.func_attr({"tir.noalias": True})
        Predictions = T.match_buffer(predictions, (5,), "float32")
        Final = T.match_buffer(final, (5,), "float32")

        for i in range(5):
            Final[i] = T.max(Predictions[i], 0.0)

    @T.prim_func
    def normalize_data(data: T.handle, normalized: T.handle):
        # Scales 8-bit-range values into [0, 1].
        T.func_attr({"tir.noalias": True})
        Data = T.match_buffer(data, (10,), "float32")
        Normalized = T.match_buffer(normalized, (10,), "float32")

        for i in range(10):
            Normalized[i] = Data[i] / 255.0
+
+
@I.ir_module
class EdgeCasePyFuncModule(BasePyModule):
    """Test edge cases and boundary conditions."""

    @I.pyfunc
    def empty_func(self):
        """Empty function with no operations."""
        pass

    @I.pyfunc
    def single_return(self, x):
        """Function with immediate return."""
        return x

    @I.pyfunc
    def nested_conditionals(self, data, threshold):
        """Function with complex nested conditional logic."""
        if data is None:
            return None

        # Dispatch on dimensionality; scalars fall through to the else arm.
        if hasattr(data, "shape"):
            if len(data.shape) == 1:
                if data.shape[0] > threshold:
                    return self._process_large_data(data)
                else:
                    return self._process_small_data(data)
            elif len(data.shape) == 2:
                return self._process_2d_data(data)
            else:
                return self._process_nd_data(data)
        else:
            return self._process_scalar_data(data)

    @I.pyfunc
    def loop_with_break(self, data, max_iter):
        """Function with loop and break statement."""
        result = []
        for i, item in enumerate(data):
            if i >= max_iter:
                break
            if item > 0:
                result.append(item * 2)
            else:
                result.append(0)
        return result

    @T.prim_func
    def dummy_tir(data: T.handle, output: T.handle):
        # Identity copy over a single-element buffer.
        T.func_attr({"tir.noalias": True})
        Data = T.match_buffer(data, (1,), "float32")
        Output = T.match_buffer(output, (1,), "float32")
        Output[0] = Data[0]
+
+
@I.ir_module
class PerformancePyFuncModule(BasePyModule):
    """Test performance optimization patterns."""

    @I.pyfunc
    def vectorized_operation(self, x, y):
        """Vectorized operation with numpy fallback."""
        try:
            # Try vectorized operation first
            if hasattr(x, "numpy") and hasattr(y, "numpy"):
                x_np = x.numpy()
                y_np = y.numpy()
                result_np = x_np + y_np
                return self._convert_numpy_to_pytorch(result_np)
        except Exception:
            # Any failure falls through to the TVM path below.
            pass

        # Fallback to TVM processing
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)
        result = self.call_tir(
            self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32")
        )
        return self._convert_tvm_to_pytorch(result)

    @I.pyfunc
    def batch_processing(self, batch_data):
        """Batch processing with memory optimization."""
        batch_size = len(batch_data)
        results = []

        # Process in chunks to optimize memory usage
        # NOTE(review): an empty batch makes chunk_size 0 and range() raise a
        # ValueError — confirm empty batches are out of scope for this test.
        chunk_size = min(batch_size, 100)
        for i in range(0, batch_size, chunk_size):
            chunk = batch_data[i : i + chunk_size]
            chunk_result = self._process_chunk(chunk)
            results.extend(chunk_result)

        return results

    @I.pyfunc
    def memory_efficient_transform(self, large_tensor):
        """Memory-efficient tensor transformation."""
        # Use in-place operations when possible
        if hasattr(large_tensor, "requires_grad") and not large_tensor.requires_grad:
            # In-place operation for efficiency
            large_tensor.add_(1.0)
            return large_tensor
        else:
            # Create new tensor if gradients are needed
            return large_tensor + 1.0

    @T.prim_func
    def vectorized_add(a: T.handle, b: T.handle, c: T.handle):
        # Element-wise addition over fixed-size (10,) float32 buffers.
        T.func_attr({"tir.noalias": True})
        A = T.match_buffer(a, (10,), "float32")
        B = T.match_buffer(b, (10,), "float32")
        C = T.match_buffer(c, (10,), "float32")

        for i in range(10):
            C[i] = A[i] + B[i]
+
+
@I.ir_module
class IntegrationPyFuncModule(BasePyModule):
    """Test integration with external libraries and complex workflows."""

    @I.pyfunc
    def sklearn_integration(self, input_data, scaler_params):
        """Integration with scikit-learn preprocessing."""
        try:
            # Import sklearn components
            from sklearn.preprocessing import StandardScaler
            from sklearn.decomposition import PCA

            # Create and fit scaler
            scaler = StandardScaler()
            if scaler_params is not None:
                # Reuse previously-fitted statistics instead of refitting.
                scaler.mean_ = scaler_params["mean"]
                scaler.scale_ = scaler_params["scale"]
            else:
                scaler.fit(input_data)

            # Transform data
            scaled_data = scaler.transform(input_data)

            # Apply PCA if needed
            if input_data.shape[1] > 10:
                pca = PCA(n_components=10)
                reduced_data = pca.fit_transform(scaled_data)
            else:
                reduced_data = scaled_data

            # Convert to TVM and process
            tvm_data = self._convert_pytorch_to_tvm(reduced_data)
            result = self.call_tir(
                self.final_transform,
                [tvm_data],
                out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"),
            )

            return self._convert_tvm_to_pytorch(result)

        except ImportError:
            # Fallback if sklearn is not available
            return self._fallback_preprocessing(input_data)

    @I.pyfunc
    def multi_stage_pipeline(self, raw_input):
        """Multi-stage processing pipeline."""
        # Stage 1: Data cleaning
        cleaned = self._clean_data(raw_input)

        # Stage 2: Feature extraction
        features = self._extract_features(cleaned)

        # Stage 3: Model inference
        predictions = self._run_inference(features)

        # Stage 4: Post-processing
        final_result = self._post_process_output(predictions)

        return final_result

    @T.prim_func
    def final_transform(data: T.handle, output: T.handle):
        # Element-wise tanh over a (10, 10) float32 buffer.
        T.func_attr({"tir.noalias": True})
        Data = T.match_buffer(data, (10, 10), "float32")
        Output = T.match_buffer(output, (10, 10), "float32")

        for i in range(10):
            for j in range(10):
                Output[i, j] = T.tanh(Data[i, j])
+
+
@I.ir_module
class ErrorHandlingPyFuncModule(BasePyModule):
    """Test comprehensive error handling and validation."""

    @I.pyfunc
    def robust_data_processing(self, input_data, config):
        """Robust data processing with comprehensive error handling."""
        try:
            # Validate inputs
            if not self._validate_inputs(input_data, config):
                raise ValueError("Invalid input data or configuration")

            # Check data types
            if not self._check_data_types(input_data):
                raise TypeError("Unsupported data types")

            # Process data with retry logic
            max_retries = config.get("max_retries", 3)
            for attempt in range(max_retries):
                try:
                    result = self._process_with_validation(input_data, config)
                    if self._validate_output(result):
                        return result
                    else:
                        raise RuntimeError("Output validation failed")
                except Exception as e:
                    # Re-raise on the final attempt; otherwise warn and retry.
                    if attempt == max_retries - 1:
                        raise
                    self._log_warning(f"Attempt {attempt + 1} failed: {e}")
                    continue

        except Exception as e:
            self._log_error(f"Data processing failed: {e}")
            return self._get_safe_fallback(input_data, config)

    @I.pyfunc
    def graceful_degradation(self, primary_input, fallback_input):
        """Function that gracefully degrades when primary path fails."""
        try:
            # Try primary processing path
            result = self._primary_processing(primary_input)
            return result
        except Exception as e:
            self._log_warning(f"Primary processing failed: {e}")

        try:
            # Try fallback path
            result = self._fallback_processing(fallback_input)
            return result
        except Exception as e2:
            self._log_error(f"Fallback processing also failed: {e2}")
            # Return safe default
            return self._get_safe_default()

    @T.prim_func
    def safe_transform(data: T.handle, output: T.handle):
        T.func_attr({"tir.noalias": True})
        Data = T.match_buffer(data, (5,), "float32")
        Output = T.match_buffer(output, (5,), "float32")

        for i in range(5):
            # Safe operation that handles edge cases
            if Data[i] > 0:
                Output[i] = T.sqrt(Data[i])
            else:
                Output[i] = 0.0
+
+
if __name__ == "__main__":
    # Direct-execution debug path; under normal pytest usage these classes
    # are validated by TVMScript automatically.
    print("All test modules defined successfully!")
    print("TVMScript will automatically validate these modules during testing.")

    # Demo the printer functionality
    print("\n" + "=" * 60)
    print("DEMO: BasePyModule Printer Functionality")
    print("=" * 60)

    # Exercise the printer with SimplePyFuncModule
    try:
        demo_module = BasePyModule(SimplePyFuncModule, tvm.cpu())

        print("\n1. Testing script() method:")
        print("-" * 40)
        script_text = demo_module.script()
        print(script_text[:500] + "..." if len(script_text) > 500 else script_text)

        print("\n2. Testing show() method:")
        print("-" * 40)
        demo_module.show()

        print("\n3. Python functions found in pyfuncs:")
        print("-" * 40)
        if hasattr(SimplePyFuncModule, "pyfuncs"):
            for pyfunc_name, pyfunc in SimplePyFuncModule.pyfuncs.items():
                print(f" - {pyfunc_name}: {pyfunc}")
        else:
            print(" No pyfuncs attribute found")

    except Exception as e:  # pylint: disable=broad-except
        print(f"Demo failed: {e}")
        print("This is expected for testing-only TVMScript code.")

    # Run all tests using tvm.testing.main()
    print("\n" + "=" * 60)
    print("Running all tests with tvm.testing.main()...")
    print("=" * 60)

    import tvm.testing

    tvm.testing.main()
+
+
+# Pytest test functions to verify the classes work correctly
def test_simple_pyfunc_module_creation():
    """Test that SimplePyFuncModule can be created."""
    # Wrap the TVMScript-decorated IRModule in a BasePyModule.
    wrapper = BasePyModule(SimplePyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions live in the IRModule's pyfuncs table, not as direct
    # attributes of the wrapper.
    if hasattr(SimplePyFuncModule, "pyfuncs"):
        assert "add" in SimplePyFuncModule.pyfuncs
        assert "multiply" in SimplePyFuncModule.pyfuncs

    # TIR kernels are exposed as attributes on the wrapper.
    assert hasattr(wrapper, "add_tir")
    assert hasattr(wrapper, "multiply_tir")

    # Relax functions may be unavailable — this TVMScript is testing-only
    # and does not compile.
    print("Note: This TVMScript is for testing purpose only, and cannot compile")
+
+
def test_complex_pyfunc_module_creation():
    """Test that ComplexPyFuncModule can be created."""
    wrapper = BasePyModule(ComplexPyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions registered via @I.pyfunc.
    if hasattr(ComplexPyFuncModule, "pyfuncs"):
        for expected in ("ml_pipeline", "data_preprocessing"):
            assert expected in ComplexPyFuncModule.pyfuncs

    # TIR kernels exposed on the wrapper.
    for tir_name in ("extract_features", "ml_inference", "post_process", "normalize_data"):
        assert hasattr(wrapper, tir_name)
+
+
def test_edge_case_pyfunc_module_creation():
    """Test that EdgeCasePyFuncModule can be created."""
    wrapper = BasePyModule(EdgeCasePyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions registered via @I.pyfunc.
    if hasattr(EdgeCasePyFuncModule, "pyfuncs"):
        for expected in ("empty_func", "single_return", "nested_conditionals", "loop_with_break"):
            assert expected in EdgeCasePyFuncModule.pyfuncs

    # TIR kernel exposed on the wrapper.
    assert hasattr(wrapper, "dummy_tir")
+
+
def test_performance_pyfunc_module_creation():
    """Test that PerformancePyFuncModule can be created."""
    wrapper = BasePyModule(PerformancePyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions registered via @I.pyfunc.
    if hasattr(PerformancePyFuncModule, "pyfuncs"):
        for expected in ("vectorized_operation", "batch_processing", "memory_efficient_transform"):
            assert expected in PerformancePyFuncModule.pyfuncs

    # TIR kernel exposed on the wrapper.
    assert hasattr(wrapper, "vectorized_add")
+
+
def test_integration_pyfunc_module_creation():
    """Test that IntegrationPyFuncModule can be created."""
    wrapper = BasePyModule(IntegrationPyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions registered via @I.pyfunc.
    if hasattr(IntegrationPyFuncModule, "pyfuncs"):
        for expected in ("sklearn_integration", "multi_stage_pipeline"):
            assert expected in IntegrationPyFuncModule.pyfuncs

    # TIR kernel exposed on the wrapper.
    assert hasattr(wrapper, "final_transform")
+
+
def test_error_handling_pyfunc_module_creation():
    """Test that ErrorHandlingPyFuncModule can be created."""
    wrapper = BasePyModule(ErrorHandlingPyFuncModule, tvm.cpu())
    assert isinstance(wrapper, BasePyModule)

    # Python functions registered via @I.pyfunc.
    if hasattr(ErrorHandlingPyFuncModule, "pyfuncs"):
        for expected in ("robust_data_processing", "graceful_degradation"):
            assert expected in ErrorHandlingPyFuncModule.pyfuncs

    # TIR kernel exposed on the wrapper.
    assert hasattr(wrapper, "safe_transform")
+
+
def test_all_modules_inherit_from_base():
    """Test that all modules properly inherit from BasePyModule."""
    device = tvm.cpu()
    for ir_module in (
        SimplePyFuncModule,
        ComplexPyFuncModule,
        EdgeCasePyFuncModule,
        PerformancePyFuncModule,
        IntegrationPyFuncModule,
        ErrorHandlingPyFuncModule,
    ):
        # Each class must back a BasePyModule exposing the printer API.
        wrapper = BasePyModule(ir_module, device)
        assert isinstance(wrapper, BasePyModule)
        assert hasattr(wrapper, "script")
        assert hasattr(wrapper, "show")
+
+
def test_pyfunc_decorators():
    """Test that all @I.pyfunc decorated functions are present."""
    # Construction must succeed before inspecting the pyfuncs table.
    BasePyModule(SimplePyFuncModule, tvm.cpu())

    if hasattr(SimplePyFuncModule, "pyfuncs"):
        table = SimplePyFuncModule.pyfuncs
        assert "add" in table
        assert "multiply" in table

        import inspect

        for func_name in ("add", "multiply"):
            func = table[func_name]
            assert callable(func)
            # Each pyfunc is expected to take (self, x, y).
            assert len(inspect.signature(func).parameters) == 3
+
+
def test_tir_functions():
    """Test that TIR functions are properly defined."""
    wrapper = BasePyModule(SimplePyFuncModule, tvm.cpu())

    for tir_name in ("add_tir", "multiply_tir"):
        # Present on the wrapper and callable (TIR entry points).
        assert hasattr(wrapper, tir_name)
        assert callable(getattr(wrapper, tir_name))
+
+
def test_relax_functions():
    """Test that Relax functions are properly defined."""
    wrapper = BasePyModule(SimplePyFuncModule, tvm.cpu())

    # Relax functions may be unavailable — this TVMScript is testing-only
    # and does not compile — so only the wrapper surface is checked.
    print("Note: This TVMScript is for testing purpose only, and cannot compile")

    assert isinstance(wrapper, BasePyModule)
    assert hasattr(wrapper, "script")
    assert hasattr(wrapper, "show")
+
+
def test_module_docstrings():
    """Test that all modules have proper docstrings."""
    device = tvm.cpu()
    for module_class in (
        SimplePyFuncModule,
        ComplexPyFuncModule,
        EdgeCasePyFuncModule,
        PerformancePyFuncModule,
        IntegrationPyFuncModule,
        ErrorHandlingPyFuncModule,
    ):
        # The TVMScript decorator rewrites the class, so instead of checking
        # docstrings directly we verify each class can back a BasePyModule.
        assert callable(module_class)
        assert isinstance(BasePyModule(module_class, device), BasePyModule)
+
+
def test_python_function_complexity():
    """Test that complex Python functions have the expected structure."""
    # Construction must succeed before inspecting the pyfuncs table.
    BasePyModule(ComplexPyFuncModule, tvm.cpu())

    if hasattr(ComplexPyFuncModule, "pyfuncs"):
        table = ComplexPyFuncModule.pyfuncs
        assert "ml_pipeline" in table
        assert "data_preprocessing" in table

        import inspect

        ml_func = table["ml_pipeline"]
        preprocess_func = table["data_preprocessing"]
        assert callable(ml_func)
        assert callable(preprocess_func)

        # (self, input_data, model_params) and (self, raw_data) respectively.
        assert len(inspect.signature(ml_func).parameters) == 3
        assert len(inspect.signature(preprocess_func).parameters) == 2
+
+
def test_script_and_show_methods():
    """Test that script() and show() methods work correctly."""
    wrapper = BasePyModule(SimplePyFuncModule, tvm.cpu())

    # script() must produce non-empty text.
    script_text = wrapper.script()
    assert isinstance(script_text, str)
    assert len(script_text) > 0

    # show() must not raise; a failure means the feature is broken.
    try:
        wrapper.show()
    except Exception as e:  # pylint: disable=broad-except
        pytest.fail(f"show() method failed: {e}")
+
+
def test_python_functions_in_irmodule():
    """Test that Python functions are properly stored in IRModule pyfuncs."""
    # Construction must succeed before inspecting the pyfuncs table.
    BasePyModule(SimplePyFuncModule, tvm.cpu())

    # Guard clause instead of if/else: missing table is a hard failure.
    if not hasattr(SimplePyFuncModule, "pyfuncs"):
        pytest.fail("pyfuncs attribute not found in IRModule")

    table = SimplePyFuncModule.pyfuncs
    assert isinstance(table, dict)
    for func_name in ("add", "multiply"):
        assert func_name in table
        assert callable(table[func_name])
        assert table[func_name].__name__ == func_name