MasterJH5574 commented on code in PR #18253: URL: https://github.com/apache/tvm/pull/18253#discussion_r2320010509
########## tests/python/relax/test_base_py_module_printer.py: ########## @@ -0,0 +1,719 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name, unused-argument + +import pytest +import tvm +from tvm.relax.base_py_module import BasePyModule +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R + + +@I.ir_module +class SimplePyFuncModule(BasePyModule): + """Test simple Python functions with basic operations.""" + + @I.pyfunc + def add(self, x, y): + """Simple addition function.""" + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")) + return self._convert_tvm_to_pytorch(result) + + @I.pyfunc + def multiply(self, x, y): + """Simple multiplication function.""" + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir( + self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (5,), 
"float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + + for i in range(5): + out[i] = x[i] + y[i] + + @T.prim_func + def multiply_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + + for i in range(5): + out[i] = x[i] * y[i] + + @R.function + def main_relax( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.add(x, y) + + +@I.ir_module +class ComplexPyFuncModule(BasePyModule): + """Test complex Python logic with ML pipeline and error handling.""" + + @I.pyfunc + def ml_pipeline(self, input_data, model_params): + """Complex ML pipeline with data validation and error handling.""" + # Data validation + if input_data is None or model_params is None: + raise ValueError("Inputs cannot be None") + + try: + # Convert to TVM format + tvm_data = self._convert_pytorch_to_tvm(input_data) + tvm_params = self._convert_pytorch_to_tvm(model_params) + + # Run ML inference + features = self.call_tir( + self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32") + ) + + predictions = self.call_tir( + self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32") + ) + + # Post-process results + final_result = self.call_tir( + self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32") + ) + + return self._convert_tvm_to_pytorch(final_result) + + except Exception as e: + self._log_error(f"ML pipeline failed: {e}") + return self._get_default_value() + + @I.pyfunc + def data_preprocessing(self, raw_data): + """Data preprocessing with conditional logic.""" + if hasattr(raw_data, "numpy"): + # Vectorized path for numpy-compatible data + data_np = raw_data.numpy() + processed = self._vectorized_preprocess(data_np) + else: + # Fallback path for other data types + processed = 
self._elementwise_preprocess(raw_data) + + # Convert and return + tvm_processed = self._convert_pytorch_to_tvm(processed) + result = self.call_tir( + self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @T.prim_func + def extract_features(data: T.handle, features: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10,), "float32") + Features = T.match_buffer(features, (10,), "float32") + + for i in range(10): + Features[i] = T.sqrt(Data[i]) + + @T.prim_func + def ml_inference(features: T.handle, params: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Features = T.match_buffer(features, (10,), "float32") + Params = T.match_buffer(params, (10,), "float32") + Output = T.match_buffer(output, (5,), "float32") + + for i in range(5): + Output[i] = Features[i] * Params[i] + Features[i + 5] * Params[i + 5] + + @T.prim_func + def post_process(predictions: T.handle, final: T.handle): + T.func_attr({"tir.noalias": True}) + Predictions = T.match_buffer(predictions, (5,), "float32") + Final = T.match_buffer(final, (5,), "float32") + + for i in range(5): + Final[i] = T.max(Predictions[i], 0.0) + + @T.prim_func + def normalize_data(data: T.handle, normalized: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10,), "float32") + Normalized = T.match_buffer(normalized, (10,), "float32") + + for i in range(10): + Normalized[i] = Data[i] / 255.0 + + +@I.ir_module +class EdgeCasePyFuncModule(BasePyModule): + """Test edge cases and boundary conditions.""" + + @I.pyfunc + def empty_func(self): + """Empty function with no operations.""" + pass + + @I.pyfunc + def single_return(self, x): + """Function with immediate return.""" + return x + + @I.pyfunc + def nested_conditionals(self, data, threshold): + """Function with complex nested conditional logic.""" + if data is None: + return None + + if 
hasattr(data, "shape"): + if 
len(data.shape) == 1: + if data.shape[0] > threshold: + return self._process_large_data(data) + else: + return self._process_small_data(data) + elif len(data.shape) == 2: + return self._process_2d_data(data) + else: + return self._process_nd_data(data) + else: + return self._process_scalar_data(data) + + @I.pyfunc + def loop_with_break(self, data, max_iter): + """Function with loop and break statement.""" + result = [] + for i, item in enumerate(data): + if i >= max_iter: + break + if item > 0: + result.append(item * 2) + else: + result.append(0) + return result + + @T.prim_func + def dummy_tir(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (1,), "float32") + Output = T.match_buffer(output, (1,), "float32") + Output[0] = Data[0] + + +@I.ir_module +class PerformancePyFuncModule(BasePyModule): + """Test performance optimization patterns.""" + + @I.pyfunc + def vectorized_operation(self, x, y): + """Vectorized operation with numpy fallback.""" + try: + # Try vectorized operation first + if hasattr(x, "numpy") and hasattr(y, "numpy"): + x_np = x.numpy() + y_np = y.numpy() + result_np = x_np + y_np + return self._convert_numpy_to_pytorch(result_np) + except Exception: + pass + + # Fallback to TVM processing + x_tvm = self._convert_pytorch_to_tvm(x) + y_tvm = self._convert_pytorch_to_tvm(y) + result = self.call_tir( + self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32") + ) + return self._convert_tvm_to_pytorch(result) + + @I.pyfunc + def batch_processing(self, batch_data): + """Batch processing with memory optimization.""" + batch_size = len(batch_data) + results = [] + + # Process in chunks to optimize memory usage + chunk_size = min(batch_size, 100) + for i in range(0, batch_size, chunk_size): + chunk = batch_data[i : i + chunk_size] + chunk_result = self._process_chunk(chunk) + results.extend(chunk_result) + + return results + + @I.pyfunc + def 
memory_efficient_transform(self, 
large_tensor): + """Memory-efficient tensor transformation.""" + # Use in-place operations when possible + if hasattr(large_tensor, "requires_grad") and not large_tensor.requires_grad: + # In-place operation for efficiency + large_tensor.add_(1.0) + return large_tensor + else: + # Create new tensor if gradients are needed + return large_tensor + 1.0 + + @T.prim_func + def vectorized_add(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"tir.noalias": True}) + A = T.match_buffer(a, (10,), "float32") + B = T.match_buffer(b, (10,), "float32") + C = T.match_buffer(c, (10,), "float32") + + for i in range(10): + C[i] = A[i] + B[i] + + +@I.ir_module +class IntegrationPyFuncModule(BasePyModule): + """Test integration with external libraries and complex workflows.""" + + @I.pyfunc + def sklearn_integration(self, input_data, scaler_params): + """Integration with scikit-learn preprocessing.""" + try: + # Import sklearn components + from sklearn.preprocessing import StandardScaler + from sklearn.decomposition import PCA + + # Create and fit scaler + scaler = StandardScaler() + if scaler_params is not None: + scaler.mean_ = scaler_params["mean"] + scaler.scale_ = scaler_params["scale"] + else: + scaler.fit(input_data) + + # Transform data + scaled_data = scaler.transform(input_data) + + # Apply PCA if needed + if input_data.shape[1] > 10: + pca = PCA(n_components=10) + reduced_data = pca.fit_transform(scaled_data) + else: + reduced_data = scaled_data + + # Convert to TVM and process + tvm_data = self._convert_pytorch_to_tvm(reduced_data) + result = self.call_tir( + self.final_transform, + [tvm_data], + out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"), + ) + + return self._convert_tvm_to_pytorch(result) + + except ImportError: + # Fallback if sklearn is not available + return self._fallback_preprocessing(input_data) + + @I.pyfunc + def multi_stage_pipeline(self, raw_input): + """Multi-stage processing pipeline.""" + # Stage 1: 
Data cleaning + cleaned = 
self._clean_data(raw_input) + + # Stage 2: Feature extraction + features = self._extract_features(cleaned) + + # Stage 3: Model inference + predictions = self._run_inference(features) + + # Stage 4: Post-processing + final_result = self._post_process_output(predictions) + + return final_result + + @T.prim_func + def final_transform(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (10, 10), "float32") + Output = T.match_buffer(output, (10, 10), "float32") + + for i in range(10): + for j in range(10): + Output[i, j] = T.tanh(Data[i, j]) + + +@I.ir_module +class ErrorHandlingPyFuncModule(BasePyModule): + """Test comprehensive error handling and validation.""" + + @I.pyfunc + def robust_data_processing(self, input_data, config): + """Robust data processing with comprehensive error handling.""" + try: + # Validate inputs + if not self._validate_inputs(input_data, config): + raise ValueError("Invalid input data or configuration") + + # Check data types + if not self._check_data_types(input_data): + raise TypeError("Unsupported data types") + + # Process data with retry logic + max_retries = config.get("max_retries", 3) + for attempt in range(max_retries): + try: + result = self._process_with_validation(input_data, config) + if self._validate_output(result): + return result + else: + raise RuntimeError("Output validation failed") + except Exception as e: + if attempt == max_retries - 1: + raise + self._log_warning(f"Attempt {attempt + 1} failed: {e}") + continue + + except Exception as e: + self._log_error(f"Data processing failed: {e}") + return self._get_safe_fallback(input_data, config) + + @I.pyfunc + def graceful_degradation(self, primary_input, fallback_input): + """Function that gracefully degrades when primary path fails.""" + try: + # Try primary processing path + result = self._primary_processing(primary_input) + return result + except Exception as e: + self._log_warning(f"Primary 
processing failed: 
{e}") + + try: + # Try fallback path + result = self._fallback_processing(fallback_input) + return result + except Exception as e2: + self._log_error(f"Fallback processing also failed: {e2}") + # Return safe default + return self._get_safe_default() + + @T.prim_func + def safe_transform(data: T.handle, output: T.handle): + T.func_attr({"tir.noalias": True}) + Data = T.match_buffer(data, (5,), "float32") + Output = T.match_buffer(output, (5,), "float32") + + for i in range(5): + # Safe operation that handles edge cases + if Data[i] > 0: + Output[i] = T.sqrt(Data[i]) + else: + Output[i] = 0.0 + + +if __name__ == "__main__": + # This allows the file to be run directly for debugging + # In normal pytest usage, these classes are automatically tested by TVMScript + print("All test modules defined successfully!") + print("TVMScript will automatically validate these modules during testing.") Review Comment: I see. Is it possible to make `__main__` also run them, given that this is a printer test? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
