MasterJH5574 commented on code in PR #18313:
URL: https://github.com/apache/tvm/pull/18313#discussion_r2361446722
##########
tests/python/relax/test_base_py_module_printer.py:
##########
@@ -758,3 +758,213 @@ def test_python_functions_in_irmodule():
assert pyfuncs["multiply"].__name__ == "multiply"
else:
pytest.fail("pyfuncs attribute not found in IRModule")
+
+
+def test_call_py_func_operator():
+ """Test R.call_py_func operator functionality."""
+ import torch
+
+ @I.ir_module
+ class CallPyFuncTestModule(BasePyModule):
+ """Test module with call_py_func usage."""
+
+ @I.pyfunc
+ def pytorch_add(self, x, y):
+ """Simple PyTorch addition."""
+ return x + y
+
+ @I.pyfunc
+ def pytorch_multiply(self, x, y):
+ """Simple PyTorch multiplication."""
+ return x * y
+
+ @R.function
+ def test_call_py_func(
+ x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+ ) -> R.Tensor((5,), "float32"):
+ # Test calling Python function from Relax
+        result = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32"))
+ return result
+
+ @R.function
+ def test_call_py_func_chain(
+ x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
+ ) -> R.Tensor((5,), "float32"):
+ # First call
+ intermediate = R.call_py_func(
+ "pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32")
+ )
+ # Second call
+ result = R.call_py_func(
+            "pytorch_multiply", (intermediate, y), out_sinfo=R.Tensor((5,), "float32")
+ )
+ return result
+
+ # Test basic functionality
+ device = tvm.cpu()
+ module = CallPyFuncTestModule(device)
+
+ # Create test tensors
+ x = torch.randn(5, dtype=torch.float32)
+ y = torch.randn(5, dtype=torch.float32)
+
+ # Test direct Python function calls
+ expected_add = x + y
+ expected_multiply = x * y
+
+ # Test through BasePyModule
+ result_add = module.call_py_func("pytorch_add", [x, y])
+ result_multiply = module.call_py_func("pytorch_multiply", [x, y])
+
+ assert torch.allclose(result_add, expected_add, atol=1e-5)
+ assert torch.allclose(result_multiply, expected_multiply, atol=1e-5)
+
+ # Test that the module has the pyfuncs
+ assert hasattr(module, "pyfuncs")
+ assert "pytorch_add" in module.pyfuncs
+ assert "pytorch_multiply" in module.pyfuncs
+
+
+def test_call_py_func_validation():
+ """Test call_py_func validation and error handling."""
+ import torch
+
+ @I.ir_module
+ class ValidationTestModule(BasePyModule):
+ """Test module for validation."""
+
+ @I.pyfunc
+ def valid_func(self, x):
+ """Valid Python function."""
+ return x * 2
+
+ @R.function
+    def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+ # This should cause a validation error
+        result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
+ return result
+
+ device = tvm.cpu()
+ module = ValidationTestModule(device)
+
+ # Test that calling non-existent function raises error
+ x = torch.randn(5, dtype=torch.float32)
+
+    with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
+ module.call_py_func("non_existent_func", [x])
+
+
+def test_call_py_func_in_relax_function():
Review Comment:
In this case I think the “complex” one is sufficient.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]