This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 731955b refactor: Get `compiled_kernel` from Triton Call Directly (#423)
731955b is described below
commit 731955b304d0f4445b981e4bb9417e49209c6c37
Author: Jinjie Liu <[email protected]>
AuthorDate: Tue Feb 10 03:29:37 2026 +0800
refactor: Get `compiled_kernel` from Triton Call Directly (#423)
A Triton kernel call returns the `CompiledKernel` directly, and the
cubin can be obtained from it. The previous code searched the device
cache for the cubin, which was unnecessary and not robust.
This PR refactors the examples to take the compiled kernel directly
from the return value of the Triton kernel call.
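For reference, a minimal sketch of the new pattern (the `dummy_kernel`
name, tensor shapes, and block size below are illustrative only, not
part of this change; the attribute access mirrors the examples in the
diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dummy_kernel(x_ptr, n_elements, BLOCK: tl.constexpr):
        # Trivial body; the call exists only to trigger compilation.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        tl.store(x_ptr + offs, tl.zeros([BLOCK], dtype=tl.float32), mask=mask)

    x = torch.empty(1024, dtype=torch.float32, device="cuda")
    # The launch returns the CompiledKernel for this specialization ...
    compiled_kernel = dummy_kernel[(1,)](x, 1024, BLOCK=1024)
    # ... so the CUBIN bytes can be read off it directly, with no need
    # to dig through device_caches (same access as in the diff below).
    cubin_bytes = compiled_kernel.kernel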
Signed-off-by: Jinjie Liu <[email protected]>
---
examples/cubin_launcher/benchmark_overhead.py | 8 +-------
examples/cubin_launcher/example_triton_cubin.py | 8 +-------
2 files changed, 2 insertions(+), 14 deletions(-)
diff --git a/examples/cubin_launcher/benchmark_overhead.py b/examples/cubin_launcher/benchmark_overhead.py
index efc82b2..230c013 100644
--- a/examples/cubin_launcher/benchmark_overhead.py
+++ b/examples/cubin_launcher/benchmark_overhead.py
@@ -132,13 +132,7 @@ def generate_cubin() -> bytes:
a_dummy = torch.empty(n, dtype=torch.float32, device="cuda")
b_dummy = torch.empty(n, dtype=torch.float32, device="cuda")
c_dummy = torch.empty(n, dtype=torch.float32, device="cuda")
- empty_kernel[1,](a_dummy, b_dummy, c_dummy, n)
-
- # Extract compiled CUBIN from the device cache
- device_caches = empty_kernel.device_caches
- device_id = next(iter(device_caches.keys()))
- cache_tuple = device_caches[device_id]
- compiled_kernel = next(iter(cache_tuple[0].values()))
+ compiled_kernel = empty_kernel[1,](a_dummy, b_dummy, c_dummy, n)
# Get CUBIN bytes
cubin_bytes = compiled_kernel.kernel
diff --git a/examples/cubin_launcher/example_triton_cubin.py b/examples/cubin_launcher/example_triton_cubin.py
index b127cd2..90243b7 100644
--- a/examples/cubin_launcher/example_triton_cubin.py
+++ b/examples/cubin_launcher/example_triton_cubin.py
@@ -67,13 +67,7 @@ def generate_cubin() -> bytes:
# Trigger kernel compilation by doing a dummy call
x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
- square_kernel[1, 1](x_dummy, y_dummy, 1024)
-
- # Extract compiled CUBIN from the device cache
- device_caches = square_kernel.device_caches
- device_id = next(iter(device_caches.keys()))
- cache_tuple = device_caches[device_id]
- compiled_kernel = next(iter(cache_tuple[0].values()))
+ compiled_kernel = square_kernel[1, 1](x_dummy, y_dummy, 1024)
# Get CUBIN bytes
cubin_bytes = compiled_kernel.kernel