Kathryn-cat commented on code in PR #18055:
URL: https://github.com/apache/tvm/pull/18055#discussion_r2147023147
##########
tests/python/codegen/test_target_codegen_cuda.py:
##########
@@ -746,5 +747,28 @@ def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
tvm.compile(func, target="cuda")
+@tvm.testing.requires_cuda
+def test_cuda_device_func_call():
+ @I.ir_module
+ class Module:
+ @T.prim_func(private=True)
+ def add(a: T.float32, b: T.float32) -> T.float32:
+ return a + b
+
+ @T.prim_func
+ def main(
Review Comment:
I think right now it's quite implicit which function is the kernel function and which is the device function. It might be clearer if we marked device functions explicitly, e.g. with `@T.prim_func(kind="device")`. We could also strengthen the coverage by adding a test case where the functions are not wrapped in a `Module`, and where we compile the kernel function directly instead of compiling the `Module`.
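
A sketch of what the proposed explicit marking might look like. Note this is illustrative only: `kind="device"` is the hypothetical parameter being suggested in this comment, not an existing `T.prim_func` argument, and the `main` signature is a placeholder since the original snippet is truncated.

```python
@I.ir_module
class Module:
    # Hypothetical `kind="device"` explicitly marks this as a device-side
    # helper rather than a kernel entry point (proposed, not yet in TVM).
    @T.prim_func(private=True, kind="device")
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func  # kernel function (default kind)
    def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
        ...  # kernel body elided
```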
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]