This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fe5a350f47 [Relax] add sample_indices in sampling (#16675)
fe5a350f47 is described below

commit fe5a350f47fd5b15f8a8a8eeb33b4b313f5c35a9
Author: Yong Wu <[email protected]>
AuthorDate: Tue Mar 5 09:49:49 2024 -0800

    [Relax] add sample_indices in sampling (#16675)
---
 python/tvm/relax/frontend/nn/op.py        | 134 +++++++++++++++++++++++-------
 tests/python/relax/test_frontend_nn_op.py | 121 +++++++++++++--------------
 2 files changed, 163 insertions(+), 92 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 6944fc8535..ae880190ad 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -2057,7 +2057,12 @@ def cumsum(
     return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
 
 
-def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str 
= "int64"):
+def multinomial_from_uniform(
+    prob: Tensor,
+    uniform_sample: Tensor,
+    sample_indices: Optional[Tensor] = None,
+    dtype: str = "int64",
+):
     """Returns a tensor where each row contains the index sampled from the 
multinomial
     probability distribution located in the corresponding row of tensor prob.
 
@@ -2075,13 +2080,25 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
         The sum of values in each row is 1, forming a valid distribution.
 
     uniform_sample : Tensor
-        The uniformly sampled 2-D tensor with the shape (batch, 1).
+        The uniformly sampled 2-D tensor with the shape (n, 1).
         Values range from 0 to 1, indicating probabilities sampled uniformly.
 
+    sample_indices : Optional[Tensor]
+        The 2-D tensor with the shape [n, 1], which indicates the specific
+        probability distribution to sample from. The value of sample_indices[i]
+        determines that the ith token should be sampled from the 
sample_indices[i]th
+        probability distribution. For instance, if there are 3 distinct 
probability
+        distributions and the requirement is to sample 2, 3, and 4 tokens from 
each,
+        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
+
+    dtype : str
+        The data type of output tensor.
+
+
     Returns
     -------
     result : Tensor
-        The computed tensor with shape (batch, 1).
+        The computed tensor with shape (n, 1).
 
     Examples
     --------
@@ -2089,29 +2106,52 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
 
         prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
         usample = [[0.4], [0.9]]
+        sample_indices = [[0], [1]]
 
         multinomial_from_uniform(prob, usample)
         -> [[1], [2]]
+        multinomial_from_uniform(prob, usample, sample_indices)
+        -> [[1], [2]]
     """
     prob_dtype = prob.dtype
     sample_dtype = uniform_sample.dtype
-    batch = prob.shape[0]
+    out_batch = uniform_sample.shape[0]
+
+    if sample_indices is not None:
+        assert (
+            sample_indices.shape == uniform_sample.shape
+        ), "The shape of sample_indices must match the shape of 
uniform_sample."
+    else:
+        assert (
+            prob.shape[0] == uniform_sample.shape[0]
+        ), "Number of samples must match the number of probability 
distributions."
+        sample_indices = 
Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))
+
+    sample_indices_dtype = sample_indices.dtype
 
     @T.prim_func(private=True)
-    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
         batch, vocab_size = T.int64(), T.int64()
         prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
-        usample = T.match_buffer(B, (batch, 1), sample_dtype)
-        output_index = T.match_buffer(C, (batch, 1), dtype)
+        out_batch = T.int64()
+        usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
+        sample_indices = T.match_buffer(C, (out_batch, 1), 
sample_indices_dtype)
+        output_index = T.match_buffer(D, (out_batch, 1), dtype)
 
-        for ax0, ax1 in T.grid(batch, vocab_size):
+        for ax0, ax1 in T.grid(out_batch, vocab_size):
             with T.block("T_get_sample_index"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.writes(output_index[v_ax0, 0])
-                if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 
1 == vocab_size:
+                if (
+                    usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, 
T.int64(0)], v_ax1]
+                    or v_ax1 + 1 == vocab_size
+                ):
                     if v_ax1 == 0:
                         output_index[v_ax0, 0] = 0
-                    elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+                    elif (
+                        usample[v_ax0, T.int64(0)]
+                        >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
+                    ):
                         output_index[v_ax0, 0] = v_ax1
 
     cumsum_prob = cumsum(prob, axis=1, exclusive=False)
@@ -2119,13 +2159,18 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
     return tensor_ir_op(
         _get_sample_index,
         "get_sample_index",
-        args=[cumsum_prob, uniform_sample],
-        out=Tensor.placeholder([batch, 1], dtype),
+        args=[cumsum_prob, uniform_sample, sample_indices],
+        out=Tensor.placeholder([out_batch, 1], dtype),
     )
 
 
 def sample_top_p_top_k_from_sorted_prob(
-    sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, 
uniform_sample: Tensor
+    sorted_prob: Tensor,
+    sorted_index: Tensor,
+    top_p: Tensor,
+    top_k: Tensor,
+    uniform_sample: Tensor,
+    sample_indices: Optional[Tensor] = None,
 ):
     """Samples indices from a sorted probability tensor based on top_p and 
top_k criteria.
 
@@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
         to consider for top-k sampling.
 
     uniform_sample : Tensor
-        Uniformly sampled values with shape (batch, 1) are used to select the 
output indices.
+        Uniformly sampled values with shape (n, 1) are used to select the 
output indices.
+
+    sample_indices : Optional[Tensor]
+        The 2-D tensor with the shape [n, 1], which indicates the specific
+        probability distribution to sample from. The value of sample_indices[i]
+        determines that the ith token should be sampled from the 
sample_indices[i]th
+        probability distribution. For instance, if there are 3 distinct 
probability
+        distributions and the requirement is to sample 2, 3, and 4 tokens from 
each,
+        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
 
     Returns
     -------
     result : Tensor
-        The selected indices with shape (batch, 1).
+        The selected indices with shape (n, 1).
 
     Examples
     --------
@@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
         top_p = [[0.6],[0.9]]
         top_k = [[3],[2]]
         uniform_sample = [[0.5], [0.6]]
+        sample_indices = [[0], [1]]
 
         sample_top_p_top_k_from_sorted_prob(
-            sorted_prob, sorted_index,top_p, top_k, uniform_sample)
+            sorted_prob, sorted_index,top_p, top_k, uniform_sample, 
sample_indices)
         -> [2, 0]
 
     """
     prob_dtype = sorted_prob.dtype
     index_dtype = sorted_index.dtype
-    batch = sorted_prob.shape[0]
+    prob_batch = sorted_prob.shape[0]
+    out_batch = uniform_sample.shape[0]
+
+    if sample_indices is not None:
+        assert (
+            sample_indices.shape == uniform_sample.shape
+        ), "The shape of sample_indices must match the shape of 
uniform_sample."
+    else:
+        assert (
+            sorted_prob.shape[0] == uniform_sample.shape[0]
+        ), "Number of samples must match the number of probability 
distributions."
+        sample_indices = Tensor.from_const(
+            np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
+        )
+        print("sample_indices: ", sample_indices)
+    sample_indices_dtype = sample_indices.dtype
 
     def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
         return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
@@ -2204,27 +2273,34 @@ def sample_top_p_top_k_from_sorted_prob(
                         renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]
 
     @T.prim_func(private=True)
-    def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle, E: T.handle):
+    def _get_index_from_sorted(
+        A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: 
T.handle
+    ):
         batch, vocab_size = T.int64(), T.int64()
+        out_batch = T.int64()
         cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
-        renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
-        usample = T.match_buffer(C, (batch, 1), prob_dtype)
-        indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
-        output_index = T.match_buffer(E, (batch, 1), index_dtype)
+        indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
+        renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
+        usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
+        sample_indices = T.match_buffer(E, (out_batch, 1), 
sample_indices_dtype)
+        output_index = T.match_buffer(F, (out_batch, 1), index_dtype)
 
-        for ax0, ax1 in T.grid(batch, vocab_size):
+        for ax0, ax1 in T.grid(out_batch, vocab_size):
             with T.block("T_get_index_from_sorted"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.writes(output_index[v_ax0, 0])
                 if (
-                    usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / 
renorm_prob[v_ax0, 0]
+                    usample[v_ax0, T.int64(0)]
+                    < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
+                    / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                     or v_ax1 + 1 == vocab_size
                 ):
                     if v_ax1 == 0:
                         output_index[v_ax0, 0] = indices[v_ax0, 0]
                     elif (
                         usample[v_ax0, T.int64(0)]
-                        >= cumsum_sorted[v_ax0, v_ax1 - 1] / 
renorm_prob[v_ax0, 0]
+                        >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], 
v_ax1 - 1]
+                        / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                     ):
                         output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
 
@@ -2235,7 +2311,7 @@ def sample_top_p_top_k_from_sorted_prob(
         "get_renorm_prob",
         args=[cumsum_sorted, top_p, top_k],
         out=Tensor.placeholder(
-            [batch, 1],
+            [prob_batch, 1],
             prob_dtype,
         ),
     )
@@ -2243,8 +2319,8 @@ def sample_top_p_top_k_from_sorted_prob(
     out_index_in_sorted = tensor_ir_op(
         _get_index_from_sorted,
         "get_index_from_sorted",
-        args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
-        out=Tensor.placeholder([batch, 1], index_dtype),
+        args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, 
sample_indices],
+        out=Tensor.placeholder([out_batch, 1], index_dtype),
     )
     return out_index_in_sorted
 
@@ -2293,7 +2369,7 @@ def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k):
         top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
         cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
         for ax0, ax1 in T.grid(batch, vocab_size):
-            with T.block("T_get_renorm_prob"):
+            with T.block("T_get_renorm_cutoff"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
                     cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 3457989a55..0d579163cd 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -846,34 +846,36 @@ def test_empty():
 @tvm.testing.requires_gpu
 def test_multinomial_from_uniform():
 
-    prob_shape = (4, 5)
-    sample_shape = (4, 1)
+    prob_shape = (3, 5)
+    sample_shape = (6, 1)
 
     class Model(Module):
-        def foo(self, prob: Tensor, uniform_sample: Tensor):
-            z0 = op.multinomial_from_uniform(prob, uniform_sample)
+        def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: 
Tensor):
+            z0 = op.multinomial_from_uniform(prob, uniform_sample, 
sample_indices)
             return z0
 
     # fmt: off
     @I.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+        def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle):
             batch, vocab_size = T.int64(), T.int64()
             prob = T.match_buffer(A, (batch, vocab_size))
-            usample = T.match_buffer(B, (batch, 1))
-            output_index = T.match_buffer(C, (batch, 1), "int64")
+            out_batch = T.int64()
+            usample = T.match_buffer(B, (out_batch, 1))
+            sample_indices = T.match_buffer(C, (out_batch, 1), "int64")
+            output_index = T.match_buffer(D, (out_batch, 1), "int64")
             # with T.block("root"):
-            for ax0, ax1 in T.grid(batch, vocab_size):
+            for ax0, ax1 in T.grid(out_batch, vocab_size):
                 with T.block("T_get_sample_index"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 - 
T.int64(1):v_ax1 - T.int64(1) + T.int64(2)])
+                    T.reads(usample[v_ax0, T.int64(0)], 
prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + 
T.int64(2)], sample_indices[v_ax0, T.int64(0)])
                     T.writes(output_index[v_ax0, 0])
-                    if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or 
v_ax1 + T.int64(1) == vocab_size:
+                    if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, 
T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size:
                         if v_ax1 == T.int64(0):
                             output_index[v_ax0, 0] = T.int64(0)
                         else:
-                            if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 
- T.int64(1)]:
+                            if usample[v_ax0, T.int64(0)] >= 
prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)]:
                                 output_index[v_ax0, 0] = v_ax1
 
         @R.function
@@ -886,13 +888,13 @@ def test_multinomial_from_uniform():
             return gv
 
         @R.function
-        def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: 
R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1), 
dtype="int64"), R.Tuple(R.Object)):
-            R.func_attr({"num_input": 3})
+        def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: 
R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), 
dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), 
R.Tuple(R.Object)):
+            R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob, 
axis=1, dtype="void", exclusive=False)
-                lv1 = R.call_tir(cls.get_sample_index, (cumsum, 
uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64"))
-                gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"), 
R.Tuple(R.Object)) = lv1, (_io,)
+                cumsum: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, 
axis=1, dtype="void", exclusive=0)
+                lv1 = R.call_tir(cls.get_sample_index, (cumsum, 
uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64"))
+                gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), 
R.Tuple(R.Object)) = lv1, (_io,)
                 R.output(gv1)
             return gv1
     # fmt: on
@@ -903,6 +905,7 @@ def test_multinomial_from_uniform():
             "foo": {
                 "prob": spec.Tensor(prob_shape, "float32"),
                 "uniform_sample": spec.Tensor(sample_shape, "float32"),
+                "sample_indices": spec.Tensor(sample_shape, "int64"),
             }
         },
         debug=True,
@@ -924,62 +927,59 @@ def test_multinomial_from_uniform():
     np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
     nd_prob = tvm.nd.array(np_prob, dev)
     # special sample to get deterministic results
-    nd_sample = tvm.nd.array(np.array([[1], [0], [0], 
[1]]).astype(np.float32), dev)
-    inputs = [nd_prob, nd_sample, effects]
+    nd_sample = tvm.nd.array(np.array([[1], [0], [1], [1], [0], 
[1]]).astype(np.float32), dev)
+    nd_sample_indices = tvm.nd.array(np.array([[0], [1], [1], [2], [2], 
[2]]).astype(np.int64), dev)
+    inputs = [nd_prob, nd_sample, nd_sample_indices, effects]
     res = vm["foo"](*inputs)
-    tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0], 
[4]]).astype(np.int64))
+    tvm.testing.assert_allclose(
+        res[0].numpy(), np.array([[4], [0], [4], [4], [0], 
[4]]).astype(np.int64)
+    )
 
 
 @tvm.testing.requires_gpu
 def test_sample_top_p_top_k_from_sorted_prob():
     prob_shape = (2, 3)
-    sample_shape = (2, 1)
+    sample_shape = (3, 1)
 
     class Model(Module):
         def foo(
-            self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor, 
uniform_sample: Tensor
+            self,
+            prob: Tensor,
+            index: Tensor,
+            top_p: Tensor,
+            top_k: Tensor,
+            uniform_sample: Tensor,
+            sample_indices: Tensor,
         ):
-            z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p, 
top_k, uniform_sample)
+            z0 = op.sample_top_p_top_k_from_sorted_prob(
+                prob, index, top_p, top_k, uniform_sample, sample_indices
+            )
             return z0
 
     # fmt: off
     @I.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle, E: T.handle):
+        def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle, E: T.handle, F: T.handle):
             batch, vocab_size = T.int64(), T.int64()
             cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
-            renorm_prob = T.match_buffer(B, (batch, 1))
-            usample = T.match_buffer(C, (batch, 1))
-            indices = T.match_buffer(D, (batch, vocab_size), "int64")
-            output_index = T.match_buffer(E, (batch, 1), "int64")
+            indices = T.match_buffer(B, (batch, vocab_size), "int64")
+            renorm_prob = T.match_buffer(C, (batch, 1))
+            out_batch = T.int64()
+            usample = T.match_buffer(D, (out_batch, 1))
+            sample_indices = T.match_buffer(E, (out_batch, 1), "int64")
+            output_index = T.match_buffer(F, (out_batch, 1), "int64")
             # with T.block("root"):
-            for ax0, ax1 in T.grid(batch, vocab_size):
+            for ax0, ax1 in T.grid(out_batch, vocab_size):
                 with T.block("T_get_index_from_sorted"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(
-                        usample[v_ax0, T.int64(0)],
-                        cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 - 
T.int64(1) + T.int64(2)],
-                        renorm_prob[v_ax0, 0],
-                        indices[
-                            v_ax0,
-                            T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1)
-                            + (T.max(T.int64(0), v_ax1) + T.int64(1) - 
T.min(T.int64(0), v_ax1)),
-                        ],
-                    )
+                    T.reads(usample[v_ax0, T.int64(0)], 
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - 
T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], 
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[v_ax0, 
T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + 
T.int64(1) - T.min(T.int64(0), v_ax1))])
                     T.writes(output_index[v_ax0, 0])
-                    if (
-                        usample[v_ax0, T.int64(0)]
-                        < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
-                        or v_ax1 + T.int64(1) == vocab_size
-                    ):
+                    if usample[v_ax0, T.int64(0)] < 
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / 
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == 
vocab_size:
                         if v_ax1 == T.int64(0):
                             output_index[v_ax0, 0] = indices[v_ax0, 0]
                         else:
-                            if (
-                                usample[v_ax0, T.int64(0)]
-                                >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] / 
renorm_prob[v_ax0, 0]
-                            ):
+                            if usample[v_ax0, T.int64(0)] >= 
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / 
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                                 output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
 
         @T.prim_func(private=True)
@@ -1015,21 +1015,14 @@ def test_sample_top_p_top_k_from_sorted_prob():
             return gv
 
         @R.function
-        def foo(
-            prob: R.Tensor((2, 3), dtype="float32"),
-            index: R.Tensor((2, 3), dtype="int64"),
-            top_p: R.Tensor((2, 1), dtype="float32"),
-            top_k: R.Tensor((2, 1), dtype="int64"),
-            uniform_sample: R.Tensor((2, 1), dtype="float32"),
-            _io: R.Object,
-        ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)):
-            R.func_attr({"num_input": 6})
+        def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 
3), dtype="int64"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: 
R.Tensor((2, 1), dtype="int64"), uniform_sample: R.Tensor((3, 1), 
dtype="float32"), sample_indices: R.Tensor((3, 1), dtype="int64"), _io: 
R.Object,) -> R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 7})
             cls = Expected
             with R.dataflow():
                 cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, 
axis=1, dtype="void", exclusive=None)
                 lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), 
out_sinfo=R.Tensor((2, 1), dtype="float32"))
-                lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1, 
uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64"))
-                gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"), 
R.Tuple(R.Object)) = lv2, (_io,)
+                lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, 
lv1, uniform_sample, sample_indices), out_sinfo=R.Tensor((3, 1), dtype="int64"))
+                gv1: R.Tuple(R.Tensor((3, 1), dtype="int64"), 
R.Tuple(R.Object)) = lv2, (_io,)
                 R.output(gv1)
             return gv1
     # fmt: on
@@ -1040,9 +1033,10 @@ def test_sample_top_p_top_k_from_sorted_prob():
             "foo": {
                 "prob": spec.Tensor(prob_shape, "float32"),
                 "index": spec.Tensor(prob_shape, "int64"),
-                "top_p": spec.Tensor(sample_shape, "float32"),
-                "top_k": spec.Tensor(sample_shape, "int64"),
+                "top_p": spec.Tensor((prob_shape[0], 1), "float32"),
+                "top_k": spec.Tensor((prob_shape[0], 1), "int64"),
                 "uniform_sample": spec.Tensor(sample_shape, "float32"),
+                "sample_indices": spec.Tensor(sample_shape, "int64"),
             }
         },
         debug=True,
@@ -1063,12 +1057,13 @@ def test_sample_top_p_top_k_from_sorted_prob():
     indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), 
dev)
     top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
     top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
-    usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev)
+    usample = tvm.nd.array(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), 
dev)
+    sample_indices = tvm.nd.array(np.array([[0], [1], [1]]).astype(np.int64), 
dev)
 
-    inputs = [sorted_prob, indices, top_p, top_k, usample, effects]
+    inputs = [sorted_prob, indices, top_p, top_k, usample, sample_indices, 
effects]
 
     res = vm["foo"](*inputs)
-    tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], 
[0]]).astype(np.int64))
+    tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0], 
[0]]).astype(np.int64))
 
 
 @tvm.testing.requires_gpu

Reply via email to