Hzfengsy opened a new pull request, #18055:
URL: https://github.com/apache/tvm/pull/18055
This commit adds support for CUDA device function calls by:
1. Modifying the calling convention handling in CUDA codegen to support both device kernel launches and device function calls
2. Updating the function signature printing to emit appropriate CUDA attributes (`__global__` vs `__device__`) based on calling convention
3. Adding a test case demonstrating device function calls
4. Fixing target handling in `split_host_device_mods` to properly handle device function dictionaries
5. Adding a safety check for global symbol extraction
The changes enable proper compilation and execution of CUDA device functions that can be called from CUDA kernels; a brief sketch of the calling-convention distinction behind items 1 and 2 follows below.
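For readers skimming the TIR side, here is a minimal, hedged sketch of the distinction the codegen relies on. It assumes TVM's existing `calling_conv` function attribute and `tvm.ir.CallingConv` enum; the helper names `is_entry_kernel` and `cuda_function_qualifier` are illustrative only and are not part of this PR.
```python
# Minimal sketch (assumed, not code from this PR): after host/device splitting,
# entry kernels carry the DEVICE_KERNEL_LAUNCH calling convention and are printed
# with __global__, while other device functions keep the default calling
# convention and are printed with __device__.
from tvm.ir import CallingConv


def is_entry_kernel(prim_func) -> bool:
    """Return True if the function is a kernel launched from the host."""
    attrs = prim_func.attrs
    if attrs is None or "calling_conv" not in attrs:
        return False  # no calling_conv attr: treated as a plain device function
    return int(attrs["calling_conv"]) == int(CallingConv.DEVICE_KERNEL_LAUNCH)


def cuda_function_qualifier(prim_func) -> str:
    """CUDA qualifier the codegen would emit for this function (illustrative)."""
    return "__global__" if is_entry_kernel(prim_func) else "__device__"
```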
Example:
```python
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        for bx in T.thread_binding(1024, "blockIdx.x"):
            for tx in T.thread_binding(1024, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
```
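For context, a hedged usage sketch follows: it assumes `tvm.build` (or your TVM version's equivalent build entry point) accepts this IRModule for a CUDA target and exposes `main` as a callable; none of this driver code is part of the PR itself.
```python
import numpy as np
import tvm

# Hedged usage sketch (not part of this PR): compile the module above for CUDA
# and call `main`, which in turn calls the private `add` device function.
dev = tvm.cuda(0)
lib = tvm.build(Module, target="cuda")  # build entry point assumed; may differ across TVM versions

a = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), dev)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
lib["main"](a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), rtol=1e-5)
```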