This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7cbaa21131 [Relax] Fix int64 row index cast in GPU multinomial sampling (#19902)
7cbaa21131 is described below
commit 7cbaa211310946271f0b15574c69486221a8ba71
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jun 29 21:30:44 2026 -0400
[Relax] Fix int64 row index cast in GPU multinomial sampling (#19902)
After the tirx refactor, a `T.let` binding no longer implicitly casts
its right-hand side to the annotated dtype. In
`gpu_multinomial_from_uniform`, `row_idx` is annotated `int64` but is
loaded from the `row_indices` buffer, whose dtype is the configurable
`sample_indices_dtype` and may be `int32`. Relying on the let annotation
to widen the value is no longer valid and yields a dtype mismatch.
Wrap the load in an explicit `T.Cast("int64", ...)` so the row index is
always converted to the annotated int64 type regardless of
`sample_indices_dtype`.
---
python/tvm/relax/backend/gpu_generic/sampling.py | 2 +-
tests/python/relax/test_backend_dispatch_sampling.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py
index 54540cbaf7..487027ce7b 100644
--- a/python/tvm/relax/backend/gpu_generic/sampling.py
+++ b/python/tvm/relax/backend/gpu_generic/sampling.py
@@ -278,7 +278,7 @@ def gpu_multinomial_from_uniform(
step_iter = T.sblock_alloc_buffer((), "int32", scope="local")
for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
- row_idx: T.let[T.int64] = row_indices[bx, 0]
+ row_idx: T.let[T.int64] = T.Cast("int64", row_indices[bx, 0])
for ty in T.thread_binding(TY, thread="threadIdx.y"):
for tx in T.thread_binding(TX, thread="threadIdx.x"):
u: T.let[T.float32] = uniform_samples[bx, 0]
diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py
index 28778e3c52..e027a788ef 100644
--- a/tests/python/relax/test_backend_dispatch_sampling.py
+++ b/tests/python/relax/test_backend_dispatch_sampling.py
@@ -99,7 +99,7 @@ def test_dispatch_multinomial_from_uniform_gpu():
sample_id_local = T.sblock_alloc_buffer((), "int64", scope="local")
step_iter = T.sblock_alloc_buffer((), "int32", scope="local")
for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
- row_idx: T.let[T.int64] = row_indices[bx, 0]
+ row_idx: T.let[T.int64] = T.Cast("int64", row_indices[bx, 0])
for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for tx in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
u: T.let[T.float32] = uniform_samples[bx, 0]