This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe5a350f47 [Relax] add sample_indices in sampling (#16675)
fe5a350f47 is described below
commit fe5a350f47fd5b15f8a8a8eeb33b4b313f5c35a9
Author: Yong Wu <[email protected]>
AuthorDate: Tue Mar 5 09:49:49 2024 -0800
[Relax] add sample_indices in sampling (#16675)
---
python/tvm/relax/frontend/nn/op.py | 134 +++++++++++++++++++++++-------
tests/python/relax/test_frontend_nn_op.py | 121 +++++++++++++--------------
2 files changed, 163 insertions(+), 92 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 6944fc8535..ae880190ad 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -2057,7 +2057,12 @@ def cumsum(
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
-def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
+def multinomial_from_uniform(
+ prob: Tensor,
+ uniform_sample: Tensor,
+ sample_indices: Optional[Tensor] = None,
+ dtype: str = "int64",
+):
"""Returns a tensor where each row contains the index sampled from the
multinomial
probability distribution located in the corresponding row of tensor prob.
@@ -2075,13 +2080,25 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
The sum of values in each row is 1, forming a valid distribution.
uniform_sample : Tensor
- The uniformly sampled 2-D tensor with the shape (batch, 1).
+ The uniformly sampled 2-D tensor with the shape (n, 1).
Values range from 0 to 1, indicating probabilities sampled uniformly.
+ sample_indices : Optional[Tensor]
+ The 2-D tensor with the shape [n, 1], which indicates the specific
+ probability distribution to sample from. The value of sample_indices[i]
+ determines that the ith token should be sampled from the sample_indices[i]th
+ probability distribution. For instance, if there are 3 distinct probability
+ distributions and the requirement is to sample 2, 3, and 4 tokens from each,
+ then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
+
+ dtype : str
+ The data type of output tensor.
+
+
Returns
-------
result : Tensor
- The computed tensor with shape (batch, 1).
+ The computed tensor with shape (n, 1).
Examples
--------
@@ -2089,29 +2106,52 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
usample = [[0.4], [0.9]]
+ sample_indices = [[0], [1]]
multinomial_from_uniform(prob, usample)
-> [[1], [2]]
+ multinomial_from_uniform(prob, usample, sample_indices)
+ -> [[1], [2]]
"""
prob_dtype = prob.dtype
sample_dtype = uniform_sample.dtype
- batch = prob.shape[0]
+ out_batch = uniform_sample.shape[0]
+
+ if sample_indices is not None:
+ assert (
+ sample_indices.shape == uniform_sample.shape
+ ), "The shape of sample_indices must match the shape of uniform_sample."
+ else:
+ assert (
+ prob.shape[0] == uniform_sample.shape[0]
+ ), "Number of samples must match the number of probability distributions."
+ sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))
+
+ sample_indices_dtype = sample_indices.dtype
@T.prim_func(private=True)
- def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+ def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
- usample = T.match_buffer(B, (batch, 1), sample_dtype)
- output_index = T.match_buffer(C, (batch, 1), dtype)
+ out_batch = T.int64()
+ usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
+ sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype)
+ output_index = T.match_buffer(D, (out_batch, 1), dtype)
- for ax0, ax1 in T.grid(batch, vocab_size):
+ for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
- if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+ if (
+ usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1]
+ or v_ax1 + 1 == vocab_size
+ ):
if v_ax1 == 0:
output_index[v_ax0, 0] = 0
- elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+ elif (
+ usample[v_ax0, T.int64(0)]
+ >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
+ ):
output_index[v_ax0, 0] = v_ax1
cumsum_prob = cumsum(prob, axis=1, exclusive=False)
@@ -2119,13 +2159,18 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
return tensor_ir_op(
_get_sample_index,
"get_sample_index",
- args=[cumsum_prob, uniform_sample],
- out=Tensor.placeholder([batch, 1], dtype),
+ args=[cumsum_prob, uniform_sample, sample_indices],
+ out=Tensor.placeholder([out_batch, 1], dtype),
)
def sample_top_p_top_k_from_sorted_prob(
- sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+ sorted_prob: Tensor,
+ sorted_index: Tensor,
+ top_p: Tensor,
+ top_k: Tensor,
+ uniform_sample: Tensor,
+ sample_indices: Optional[Tensor] = None,
):
"""Samples indices from a sorted probability tensor based on top_p and top_k criteria.
@@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
to consider for top-k sampling.
uniform_sample : Tensor
- Uniformly sampled values with shape (batch, 1) are used to select the output indices.
+ Uniformly sampled values with shape (n, 1) are used to select the output indices.
+
+ sample_indices : Optional[Tensor]
+ The 2-D tensor with the shape [n, 1], which indicates the specific
+ probability distribution to sample from. The value of sample_indices[i]
+ determines that the ith token should be sampled from the sample_indices[i]th
+ probability distribution. For instance, if there are 3 distinct probability
+ distributions and the requirement is to sample 2, 3, and 4 tokens from each,
+ then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
Returns
-------
result : Tensor
- The selected indices with shape (batch, 1).
+ The selected indices with shape (n, 1).
Examples
--------
@@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
top_p = [[0.6],[0.9]]
top_k = [[3],[2]]
uniform_sample = [[0.5], [0.6]]
+ sample_indices = [[0], [1]]
sample_top_p_top_k_from_sorted_prob(
- sorted_prob, sorted_index,top_p, top_k, uniform_sample)
+ sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices)
-> [2, 0]
"""
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
- batch = sorted_prob.shape[0]
+ prob_batch = sorted_prob.shape[0]
+ out_batch = uniform_sample.shape[0]
+
+ if sample_indices is not None:
+ assert (
+ sample_indices.shape == uniform_sample.shape
+ ), "The shape of sample_indices must match the shape of uniform_sample."
+ else:
+ assert (
+ sorted_prob.shape[0] == uniform_sample.shape[0]
+ ), "Number of samples must match the number of probability distributions."
+ sample_indices = Tensor.from_const(
+ np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
+ )
+ print("sample_indices: ", sample_indices)
+ sample_indices_dtype = sample_indices.dtype
def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
@@ -2204,27 +2273,34 @@ def sample_top_p_top_k_from_sorted_prob(
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]
@T.prim_func(private=True)
- def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+ def _get_index_from_sorted(
+ A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
+ ):
batch, vocab_size = T.int64(), T.int64()
+ out_batch = T.int64()
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
- renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
- usample = T.match_buffer(C, (batch, 1), prob_dtype)
- indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
- output_index = T.match_buffer(E, (batch, 1), index_dtype)
+ indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
+ renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
+ usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
+ sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
+ output_index = T.match_buffer(F, (out_batch, 1), index_dtype)
- for ax0, ax1 in T.grid(batch, vocab_size):
+ for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if (
- usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+ usample[v_ax0, T.int64(0)]
+ < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
+ / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
elif (
usample[v_ax0, T.int64(0)]
- >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
+ >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
+ / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
@@ -2235,7 +2311,7 @@ def sample_top_p_top_k_from_sorted_prob(
"get_renorm_prob",
args=[cumsum_sorted, top_p, top_k],
out=Tensor.placeholder(
- [batch, 1],
+ [prob_batch, 1],
prob_dtype,
),
)
@@ -2243,8 +2319,8 @@ def sample_top_p_top_k_from_sorted_prob(
out_index_in_sorted = tensor_ir_op(
_get_index_from_sorted,
"get_index_from_sorted",
- args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
- out=Tensor.placeholder([batch, 1], index_dtype),
+ args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
+ out=Tensor.placeholder([out_batch, 1], index_dtype),
)
return out_index_in_sorted
@@ -2293,7 +2369,7 @@ def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k):
top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
- with T.block("T_get_renorm_prob"):
+ with T.block("T_get_renorm_cutoff"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 3457989a55..0d579163cd 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -846,34 +846,36 @@ def test_empty():
@tvm.testing.requires_gpu
def test_multinomial_from_uniform():
- prob_shape = (4, 5)
- sample_shape = (4, 1)
+ prob_shape = (3, 5)
+ sample_shape = (6, 1)
class Model(Module):
- def foo(self, prob: Tensor, uniform_sample: Tensor):
- z0 = op.multinomial_from_uniform(prob, uniform_sample)
+ def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: Tensor):
+ z0 = op.multinomial_from_uniform(prob, uniform_sample, sample_indices)
return z0
# fmt: off
@I.ir_module
class Expected:
@T.prim_func(private=True)
- def get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size))
- usample = T.match_buffer(B, (batch, 1))
- output_index = T.match_buffer(C, (batch, 1), "int64")
+ out_batch = T.int64()
+ usample = T.match_buffer(B, (out_batch, 1))
+ sample_indices = T.match_buffer(C, (out_batch, 1), "int64")
+ output_index = T.match_buffer(D, (out_batch, 1), "int64")
# with T.block("root"):
- for ax0, ax1 in T.grid(batch, vocab_size):
+ for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 -
T.int64(1):v_ax1 - T.int64(1) + T.int64(2)])
+ T.reads(usample[v_ax0, T.int64(0)],
prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) +
T.int64(2)], sample_indices[v_ax0, T.int64(0)])
T.writes(output_index[v_ax0, 0])
- if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + T.int64(1) == vocab_size:
+ if usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1] or v_ax1 + T.int64(1) == vocab_size:
if v_ax1 == T.int64(0):
output_index[v_ax0, 0] = T.int64(0)
else:
- if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - T.int64(1)]:
+ if usample[v_ax0, T.int64(0)] >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)]:
output_index[v_ax0, 0] = v_ax1
@R.function
@@ -886,13 +888,13 @@ def test_multinomial_from_uniform():
return gv
@R.function
- def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)):
- R.func_attr({"num_input": 3})
+ def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
- cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=False)
- lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64"))
- gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,)
+ cumsum: R.Tensor((3, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=0)
+ lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample, sample_indices), out_sinfo=R.Tensor((6, 1), dtype="int64"))
+ gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,)
R.output(gv1)
return gv1
# fmt: on
@@ -903,6 +905,7 @@ def test_multinomial_from_uniform():
"foo": {
"prob": spec.Tensor(prob_shape, "float32"),
"uniform_sample": spec.Tensor(sample_shape, "float32"),
+ "sample_indices": spec.Tensor(sample_shape, "int64"),
}
},
debug=True,
@@ -924,62 +927,59 @@ def test_multinomial_from_uniform():
np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
nd_prob = tvm.nd.array(np_prob, dev)
# special sample to get deterministic results
- nd_sample = tvm.nd.array(np.array([[1], [0], [0], [1]]).astype(np.float32), dev)
- inputs = [nd_prob, nd_sample, effects]
+ nd_sample = tvm.nd.array(np.array([[1], [0], [1], [1], [0], [1]]).astype(np.float32), dev)
+ nd_sample_indices = tvm.nd.array(np.array([[0], [1], [1], [2], [2], [2]]).astype(np.int64), dev)
+ inputs = [nd_prob, nd_sample, nd_sample_indices, effects]
res = vm["foo"](*inputs)
- tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0], [4]]).astype(np.int64))
+ tvm.testing.assert_allclose(
+ res[0].numpy(), np.array([[4], [0], [4], [4], [0], [4]]).astype(np.int64)
+ )
@tvm.testing.requires_gpu
def test_sample_top_p_top_k_from_sorted_prob():
prob_shape = (2, 3)
- sample_shape = (2, 1)
+ sample_shape = (3, 1)
class Model(Module):
def foo(
- self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+ self,
+ prob: Tensor,
+ index: Tensor,
+ top_p: Tensor,
+ top_k: Tensor,
+ uniform_sample: Tensor,
+ sample_indices: Tensor,
):
- z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p, top_k, uniform_sample)
+ z0 = op.sample_top_p_top_k_from_sorted_prob(
+ prob, index, top_p, top_k, uniform_sample, sample_indices
+ )
return z0
# fmt: off
@I.ir_module
class Expected:
@T.prim_func(private=True)
- def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
batch, vocab_size = T.int64(), T.int64()
cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
- renorm_prob = T.match_buffer(B, (batch, 1))
- usample = T.match_buffer(C, (batch, 1))
- indices = T.match_buffer(D, (batch, vocab_size), "int64")
- output_index = T.match_buffer(E, (batch, 1), "int64")
+ indices = T.match_buffer(B, (batch, vocab_size), "int64")
+ renorm_prob = T.match_buffer(C, (batch, 1))
+ out_batch = T.int64()
+ usample = T.match_buffer(D, (out_batch, 1))
+ sample_indices = T.match_buffer(E, (out_batch, 1), "int64")
+ output_index = T.match_buffer(F, (out_batch, 1), "int64")
# with T.block("root"):
- for ax0, ax1 in T.grid(batch, vocab_size):
+ for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(
- usample[v_ax0, T.int64(0)],
- cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 -
T.int64(1) + T.int64(2)],
- renorm_prob[v_ax0, 0],
- indices[
- v_ax0,
- T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1)
- + (T.max(T.int64(0), v_ax1) + T.int64(1) -
T.min(T.int64(0), v_ax1)),
- ],
- )
+ T.reads(usample[v_ax0, T.int64(0)],
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 -
T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)],
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[v_ax0,
T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) +
T.int64(1) - T.min(T.int64(0), v_ax1))])
T.writes(output_index[v_ax0, 0])
- if (
- usample[v_ax0, T.int64(0)]
- < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
- or v_ax1 + T.int64(1) == vocab_size
- ):
+ if usample[v_ax0, T.int64(0)] <
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] /
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) ==
vocab_size:
if v_ax1 == T.int64(0):
output_index[v_ax0, 0] = indices[v_ax0, 0]
else:
- if (
- usample[v_ax0, T.int64(0)]
- >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] /
renorm_prob[v_ax0, 0]
- ):
+ if usample[v_ax0, T.int64(0)] >=
cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] /
renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
@T.prim_func(private=True)
@@ -1015,21 +1015,14 @@ def test_sample_top_p_top_k_from_sorted_prob():
return gv
@R.function
- def foo(
- prob: R.Tensor((2, 3), dtype="float32"),
- index: R.Tensor((2, 3), dtype="int64"),
- top_p: R.Tensor((2, 1), dtype="float32"),
- top_k: R.Tensor((2, 1), dtype="int64"),
- uniform_sample: R.Tensor((2, 1), dtype="float32"),
- _io: R.Object,
- ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)):
- R.func_attr({"num_input": 6})
+ def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype="int64"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), uniform_sample: R.Tensor((3, 1), dtype="float32"), sample_indices: R.Tensor((3, 1), dtype="int64"), _io: R.Object,) -> R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 7})
cls = Expected
with R.dataflow():
cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob,
axis=1, dtype="void", exclusive=None)
lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k),
out_sinfo=R.Tensor((2, 1), dtype="float32"))
- lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1, uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64"))
- gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,)
+ lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, lv1, uniform_sample, sample_indices), out_sinfo=R.Tensor((3, 1), dtype="int64"))
+ gv1: R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,)
R.output(gv1)
return gv1
# fmt: on
@@ -1040,9 +1033,10 @@ def test_sample_top_p_top_k_from_sorted_prob():
"foo": {
"prob": spec.Tensor(prob_shape, "float32"),
"index": spec.Tensor(prob_shape, "int64"),
- "top_p": spec.Tensor(sample_shape, "float32"),
- "top_k": spec.Tensor(sample_shape, "int64"),
+ "top_p": spec.Tensor((prob_shape[0], 1), "float32"),
+ "top_k": spec.Tensor((prob_shape[0], 1), "int64"),
"uniform_sample": spec.Tensor(sample_shape, "float32"),
+ "sample_indices": spec.Tensor(sample_shape, "int64"),
}
},
debug=True,
@@ -1063,12 +1057,13 @@ def test_sample_top_p_top_k_from_sorted_prob():
indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev)
top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
- usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev)
+ usample = tvm.nd.array(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), dev)
+ sample_indices = tvm.nd.array(np.array([[0], [1], [1]]).astype(np.int64), dev)
- inputs = [sorted_prob, indices, top_p, top_k, usample, effects]
+ inputs = [sorted_prob, indices, top_p, top_k, usample, sample_indices, effects]
res = vm["foo"](*inputs)
- tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0]]).astype(np.int64))
+ tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0], [0]]).astype(np.int64))
@tvm.testing.requires_gpu