This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7cbaa21131 [Relax] Fix int64 row index cast in GPU multinomial sampling (#19902)
7cbaa21131 is described below

commit 7cbaa211310946271f0b15574c69486221a8ba71
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jun 29 21:30:44 2026 -0400

    [Relax] Fix int64 row index cast in GPU multinomial sampling (#19902)
    
    After the tirx refactor, a `T.let` binding no longer implicitly casts
    its right-hand side to the annotated dtype. In
    `gpu_multinomial_from_uniform`, `row_idx` is annotated `int64` but is
    loaded from the `row_indices` buffer, whose dtype is the configurable
    `sample_indices_dtype` and may be `int32`. Relying on the let annotation
    to widen the value is no longer valid and yields a dtype mismatch.
    
    Wrap the load in an explicit `T.Cast("int64", ...)` so the row index is
    always converted to the annotated int64 type regardless of
    `sample_indices_dtype`.
---
 python/tvm/relax/backend/gpu_generic/sampling.py     | 2 +-
 tests/python/relax/test_backend_dispatch_sampling.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py
index 54540cbaf7..487027ce7b 100644
--- a/python/tvm/relax/backend/gpu_generic/sampling.py
+++ b/python/tvm/relax/backend/gpu_generic/sampling.py
@@ -278,7 +278,7 @@ def gpu_multinomial_from_uniform(
         step_iter = T.sblock_alloc_buffer((), "int32", scope="local")
 
         for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
-            row_idx: T.let[T.int64] = row_indices[bx, 0]
+            row_idx: T.let[T.int64] = T.Cast("int64", row_indices[bx, 0])
             for ty in T.thread_binding(TY, thread="threadIdx.y"):
                 for tx in T.thread_binding(TX, thread="threadIdx.x"):
                     u: T.let[T.float32] = uniform_samples[bx, 0]
diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py
index 28778e3c52..e027a788ef 100644
--- a/tests/python/relax/test_backend_dispatch_sampling.py
+++ b/tests/python/relax/test_backend_dispatch_sampling.py
@@ -99,7 +99,7 @@ def test_dispatch_multinomial_from_uniform_gpu():
             sample_id_local = T.sblock_alloc_buffer((), "int64", scope="local")
             step_iter = T.sblock_alloc_buffer((), "int32", scope="local")
             for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
-                row_idx: T.let[T.int64] = row_indices[bx, 0]
+                row_idx: T.let[T.int64] = T.Cast("int64", row_indices[bx, 0])
                 for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                     for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                         u: T.let[T.float32] = uniform_samples[bx, 0]

Reply via email to