This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 84b3f69edb [Unity][SLM] GPU sampling (#16575)
84b3f69edb is described below
commit 84b3f69edba618d258d51cddd92618538c28ffb4
Author: Yong Wu <[email protected]>
AuthorDate: Fri Feb 23 06:07:28 2024 -0800
[Unity][SLM] GPU sampling (#16575)
This PR adds GPU sampling support to SLM
---
python/tvm/relax/frontend/nn/_tensor_op.py | 16 +
python/tvm/relax/frontend/nn/op.py | 501 +++++++++++++++++++++++++++++
src/runtime/relax_vm/lm_support.cc | 37 +++
tests/python/relax/test_frontend_nn_op.py | 365 ++++++++++++++++++++-
tests/python/relax/test_vm_builtin.py | 57 ++++
5 files changed, 973 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py
b/python/tvm/relax/frontend/nn/_tensor_op.py
index 3a646e29b8..7f44ca2438 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -67,6 +67,22 @@ class _TensorOp:
other = _convert_scalar(other, self)
return _op().divide(self, other)
+    def __lt__(self, other):
+        """Build the element-wise ``self < other`` comparison through the ``less`` op."""
+        rhs = _convert_scalar(other, self)
+        return _op().less(self, rhs)
+
+    def __le__(self, other):
+        """Build the element-wise ``self <= other`` comparison through the ``less_equal`` op."""
+        rhs = _convert_scalar(other, self)
+        return _op().less_equal(self, rhs)
+
+    def __gt__(self, other):
+        """Build the element-wise ``self > other`` comparison through the ``greater`` op."""
+        rhs = _convert_scalar(other, self)
+        return _op().greater(self, rhs)
+
+    def __ge__(self, other):
+        """Build the element-wise ``self >= other`` comparison through the ``greater_equal`` op."""
+        rhs = _convert_scalar(other, self)
+        return _op().greater_equal(self, rhs)
+
def astype(self, dtype):
return _op().astype(self, dtype)
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index b6c34ca265..6944fc8535 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -24,6 +24,8 @@ from typing import Any, Callable, Dict, List, Optional,
Sequence, Tuple, TypeVar
import numpy as np
from tvm import tir as _tir
+from tvm.script import tir as T
+from tvm import te
from ... import expr as rx
from ... import op as _op
@@ -1825,3 +1827,502 @@ def print_(tensor: Tensor):
filename, line_number =
inspect.getframeinfo(inspect.currentframe().f_back)[:2]
line_info = f"{filename}:{line_number}"
debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a < b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.less(lhs_expr, rhs_expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a <= b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.less_equal(lhs_expr, rhs_expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a > b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.greater(lhs_expr, rhs_expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a >= b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.greater_equal(lhs_expr, rhs_expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a == b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.equal(lhs_expr, rhs_expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Compare two tensors element-wise, with broadcasting, computing ``a != b``.
+
+    Parameters
+    ----------
+    a : Tensor
+        Left-hand operand of the comparison.
+
+    b : Tensor
+        Right-hand operand of the comparison.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result of the comparison.
+    """
+    lhs_expr, rhs_expr = a._expr, b._expr
+    return wrap_nested(_op.not_equal(lhs_expr, rhs_expr), name)
+
+
+def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
+    """Select elements from one of two input tensors depending on the value of a condition.
+
+    At each position the result holds the value from `x1` where `condition` is True
+    and the value from `x2` otherwise.
+
+    Parameters
+    ----------
+    condition : Tensor
+        Boolean selector: True picks `x1`, False picks `x2`.
+        Must have boolean dtype and be broadcasting compatible with `x1` and `x2`.
+
+    x1 : Tensor
+        Values used where `condition` is True.
+        Must be broadcasting compatible with `condition` and `x2`.
+
+    x2 : Tensor
+        Values used where `condition` is False.
+        Must be broadcasting compatible with `condition` and `x1`.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor.
+    """
+    cond_expr = condition._expr
+    true_expr, false_expr = x1._expr, x2._expr
+    return wrap_nested(_op.where(cond_expr, true_expr, false_expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """NumPy-style cumulative sum of the elements of ``data`` along ``axis``.
+
+    Parameters
+    ----------
+    data : Tensor
+        Input tensor to accumulate.
+
+    axis : Optional[int]
+        Axis to accumulate along. When left as None, the sum runs over the flattened input.
+
+    dtype : Optional[str]
+        Data type of both the accumulator and the returned tensor. When left as None, the
+        dtype of ``data`` is used.
+
+    exclusive : Optional[bool]
+        When true, compute the exclusive sum, in which the first element is not included.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        Tensor with as many elements as ``data``. It has the shape of ``data`` when ``axis``
+        is given; when ``axis`` is None it is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # no axis given: accumulate over the flattened input
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    scanned = _op.cumsum(data._expr, axis, dtype, exclusive)
+    return wrap_nested(scanned, name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
+    """Returns a tensor where each row contains the index sampled from the multinomial
+    probability distribution located in the corresponding row of tensor prob.
+
+    Notes
+    -----
+    For better CPU performance, use 'vm.builtin.multinomial_from_uniform'.
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    prob : Tensor
+        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+        Each row is a distribution across the vocabulary for one batch entry, where:
+        Values range from [0, 1], indicating the probability of each vocabulary item.
+        The sum of values in each row is 1, forming a valid distribution.
+
+    uniform_sample : Tensor
+        The uniformly sampled 2-D tensor with the shape (batch, 1).
+        Values range from 0 to 1, indicating probabilities sampled uniformly.
+
+    dtype : str
+        Data type of the returned index tensor. Defaults to "int64".
+
+    Returns
+    -------
+    result : Tensor
+        The computed tensor with shape (batch, 1).
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
+        usample = [[0.4], [0.9]]
+
+        multinomial_from_uniform(prob, usample)
+        -> [[1], [2]]
+    """
+    prob_dtype = prob.dtype
+    sample_dtype = uniform_sample.dtype
+    batch = prob.shape[0]
+
+    # NOTE: inside the kernel the buffer named `prob` is bound to the first call argument,
+    # which is the row-wise inclusive cumulative sum built below, not the raw probabilities.
+    @T.prim_func(private=True)
+    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        usample = T.match_buffer(B, (batch, 1), sample_dtype)
+        output_index = T.match_buffer(C, (batch, 1), dtype)
+
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_sample_index"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.writes(output_index[v_ax0, 0])
+                # Column j is written when cumsum[j - 1] <= u < cumsum[j]. The last column
+                # additionally accepts any u that is not below the final cumulative value.
+                if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+                    if v_ax1 == 0:
+                        # First column has no predecessor to compare against.
+                        output_index[v_ax0, 0] = 0
+                    elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+                        output_index[v_ax0, 0] = v_ax1
+
+    # Inclusive cumulative sum along the vocabulary axis; this is what the kernel reads.
+    cumsum_prob = cumsum(prob, axis=1, exclusive=False)
+
+    return tensor_ir_op(
+        _get_sample_index,
+        "get_sample_index",
+        args=[cumsum_prob, uniform_sample],
+        out=Tensor.placeholder([batch, 1], dtype),
+    )
+
+
+def sample_top_p_top_k_from_sorted_prob(
+    sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+):
+    """Samples indices from a sorted probability tensor based on top_p and top_k criteria.
+
+    Notes
+    -----
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    sorted_prob : Tensor
+        A 2-D tensor, with shape (batch, vocab_size), contains probabilities
+        sorted in descending order.
+
+    sorted_index : Tensor
+        The indices tensor with shape (batch, vocab_size), corresponding to the
+        sorted_prob. Potentially from applying argsort on the original probability
+        tensor in descending order.
+
+    top_p : Tensor
+        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+    top_k : Tensor
+        A tensor with shape (batch, 1), representing the number of top probabilities
+        to consider for top-k sampling.
+
+    uniform_sample : Tensor
+        Uniformly sampled values with shape (batch, 1), used to select the output indices.
+
+    Returns
+    -------
+    result : Tensor
+        The selected indices with shape (batch, 1).
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.1, 0.4, 0.5],
+                [0.3, 0.3, 0.4]]
+        sorted_prob = [[0.5, 0.4, 0.1],
+                       [0.4, 0.3, 0.3]]
+        sorted_index = [[2, 1, 0],
+                        [2, 0, 1]]
+        top_p = [[0.6], [0.9]]
+        top_k = [[3], [2]]
+        uniform_sample = [[0.5], [0.6]]
+
+        sample_top_p_top_k_from_sorted_prob(
+            sorted_prob, sorted_index, top_p, top_k, uniform_sample)
+        -> [[2], [0]]
+
+    """
+    prob_dtype = sorted_prob.dtype
+    index_dtype = sorted_index.dtype
+    batch = sorted_prob.shape[0]
+
+    # True while column j is still inside both limits: the cumulative mass up to and
+    # including j is below top_p, and j + 1 is below top_k.
+    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
+        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
+
+    # Per row, writes the cumulative mass at the top-p/top-k cutoff column. The sampling
+    # kernel below divides the cumulative sums by this value.
+    # NOTE: the top_p and (below) usample buffers are declared with prob_dtype, so those
+    # inputs are expected to share the dtype of sorted_prob.
+    @T.prim_func(private=True)
+    def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        top_p = T.match_buffer(B, (batch, 1), prob_dtype)
+        top_k = T.match_buffer(C, (batch, 1), index_dtype)
+        renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype)
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_renorm_prob"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                # Mask already fails at column 0: the cutoff is the first column.
+                if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
+                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
+                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
+                    # Mask holds up to the last column: the cutoff is the last column.
+                    if v_ax1 + 1 == vocab_size:
+                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
+                    # Mask holds here but not at the next column: the next column is the cutoff.
+                    elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
+                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]
+
+    # Per row, finds the column j with c[j - 1] <= u < c[j], where c is the cumulative sum
+    # divided by the row's renorm_prob, and writes the original index stored at that column.
+    @T.prim_func(private=True)
+    def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
+        usample = T.match_buffer(C, (batch, 1), prob_dtype)
+        indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
+        output_index = T.match_buffer(E, (batch, 1), index_dtype)
+
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_index_from_sorted"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.writes(output_index[v_ax0, 0])
+                # The last column also accepts any u not below the final scaled cumulative value.
+                if (
+                    usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+                    or v_ax1 + 1 == vocab_size
+                ):
+                    if v_ax1 == 0:
+                        # First column has no predecessor to compare against.
+                        output_index[v_ax0, 0] = indices[v_ax0, 0]
+                    elif (
+                        usample[v_ax0, T.int64(0)]
+                        >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
+                    ):
+                        output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
+
+    # Cumulative sum of the sorted probabilities along the vocabulary axis.
+    cumsum_sorted = cumsum(sorted_prob, axis=1)
+
+    renorm_prob = tensor_ir_op(
+        _get_renorm_prob,
+        "get_renorm_prob",
+        args=[cumsum_sorted, top_p, top_k],
+        out=Tensor.placeholder(
+            [batch, 1],
+            prob_dtype,
+        ),
+    )
+
+    out_index_in_sorted = tensor_ir_op(
+        _get_index_from_sorted,
+        "get_index_from_sorted",
+        args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
+        out=Tensor.placeholder([batch, 1], index_dtype),
+    )
+    return out_index_in_sorted
+
+
+def renormalize_top_p_top_k_prob(
+    prob: Tensor, sorted_prob: Tensor, top_p: Tensor, top_k: Tensor
+):
+    """Renormalizes probabilities after filtering with top_p and top_k, ensuring
+    they sum up to 1.
+
+    Notes
+    -----
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    prob : Tensor
+        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+
+    sorted_prob : Tensor
+        Probabilities sorted in descending order.
+
+    top_p : Tensor
+        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+    top_k : Tensor
+        A tensor with shape (batch, 1), representing the number of top probabilities
+        to consider for top-k sampling.
+
+    Returns
+    -------
+    result : Tensor
+        The filtered and normalized tensor with the same shape as input prob.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.2, 0.3, 0.5],
+                [0.3, 0.3, 0.4]]
+        sorted_prob = [[0.5, 0.3, 0.2],
+                       [0.4, 0.3, 0.3]]
+        top_p = [[0.6], [0.9]]
+        top_k = [[3], [2]]
+
+        renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k)
+        -> [[0, 0.375, 0.625],
+            [0.3, 0.3, 0.4]]
+    """
+    prob_dtype = prob.dtype
+    top_k_dtype = top_k.dtype
+    batch = sorted_prob.shape[0]
+
+    # True while column j is still inside both limits: the cumulative mass up to and
+    # including j is below top_p, and j + 1 is below top_k.
+    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
+        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
+
+    # Per row, writes the sorted probability found at the top-p/top-k cutoff column.
+    # NOTE: the top_p buffer is declared with prob_dtype, so top_p is expected to share
+    # the dtype of prob.
+    @T.prim_func(private=True)
+    def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        sorted_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        cumsum_sorted = T.match_buffer(B, (batch, vocab_size), prob_dtype)
+        top_p = T.match_buffer(C, (batch, 1), prob_dtype)
+        top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
+        cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            # NOTE(review): the block label says "renorm_prob" although this kernel writes
+            # the cutoff; the expected IR in the accompanying test uses the same label.
+            with T.block("T_get_renorm_prob"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                # Mask already fails at column 0: the cutoff is the first column.
+                if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
+                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
+                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
+                    # Mask holds up to the last column: the cutoff is the last column.
+                    if v_ax1 + 1 == vocab_size:
+                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
+                    # Mask holds here but not at the next column: the next column is the cutoff.
+                    elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
+                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]
+
+    # Cumulative sum of the sorted probabilities along the vocabulary axis.
+    cumsum_sorted = cumsum(sorted_prob, axis=1)
+
+    renorm_cutoff = tensor_ir_op(
+        _get_renorm_cutoff,
+        "get_renorm_cutoff",
+        args=[sorted_prob, cumsum_sorted, top_p, top_k],
+        out=Tensor.placeholder(
+            [batch, 1],
+            prob_dtype,
+        ),
+    )
+
+    # Keep every entry of the unsorted prob that is at least its row's cutoff; zero the rest.
+    filtered_prob = tensor_expr_op(
+        lambda prob, renorm_cutoff: te.compute(
+            prob.shape,
+            lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], 0.0),
+            name="filter_with_top_p_top_k",
+        ),
+        "filter_with_top_p_top_k",
+        args=[prob, renorm_cutoff],
+    )
+    # Divide by the per-row sum of the surviving entries so each row sums to 1 again.
+    # NOTE(review): `sum` is called with axis/keepdims, so it is presumably this module's
+    # tensor op rather than the Python builtin - confirm.
+    renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True)
+    return renorm_prob
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index cfb78006d7..95dca0c6d5 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -496,6 +496,43 @@ int SampleTopPFromProb(NDArray prob, double top_p, double
uniform_sample) {
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
+/*!
+ * \brief Sample one index per row of `prob` from a pre-drawn uniform sample.
+ *
+ * For each row, probabilities are accumulated from the first entry onwards and the first
+ * index whose running sum exceeds that row's sample is returned. If no running sum exceeds
+ * the sample, the last index is returned.
+ *
+ * \param prob Contiguous float32 array; the first dimension is the batch and the last
+ *        dimension is the vocabulary.
+ * \param uniform_sample Contiguous float32 array holding one sample per row of `prob`.
+ * \return An int64 array of shape (batch, 1) on the CPU holding the sampled indices.
+ */
+NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) {
+  ICHECK(prob.IsContiguous());
+  ICHECK(uniform_sample.IsContiguous());
+  // Both buffers are reinterpreted as float* below, so reject any other dtype up front
+  // instead of silently misreading the data.
+  ICHECK(prob.DataType() == DataType::Float(32)) << "prob must be float32";
+  ICHECK(uniform_sample.DataType() == DataType::Float(32)) << "uniform_sample must be float32";
+  // shape[0] and shape[ndim - 1] are read below.
+  ICHECK_GT(prob->ndim, 0) << "prob must have at least one dimension";
+
+  if (prob->device.device_type != kDLCPU) {
+    prob = prob.CopyTo(DLDevice{kDLCPU, 0});
+  }
+  if (uniform_sample->device.device_type != kDLCPU) {
+    uniform_sample = uniform_sample.CopyTo(DLDevice{kDLCPU, 0});
+  }
+
+  ICHECK(prob->device.device_type == kDLCPU);
+  ICHECK(uniform_sample->device.device_type == kDLCPU);
+
+  int64_t batch_size = prob->shape[0];
+  int64_t vocab_size = prob->shape[prob->ndim - 1];
+
+  // psample[i] is read for every row, so there must be at least one sample per row.
+  int64_t num_samples = 1;
+  for (int k = 0; k < uniform_sample->ndim; ++k) {
+    num_samples *= uniform_sample->shape[k];
+  }
+  ICHECK_GE(num_samples, batch_size) << "uniform_sample must provide one value per row of prob";
+
+  const float* pprob = static_cast<float*>(prob->data);
+  const float* psample = static_cast<float*>(uniform_sample->data);
+  NDArray new_array = NDArray::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device);
+  int64_t* parray = static_cast<int64_t*>(new_array->data);
+  for (int64_t i = 0; i < batch_size; ++i) {
+    float cum_sum_prob = 0.0f;
+    int64_t prob_idx = 0;
+    for (int64_t j = 0; j < vocab_size; ++j) {
+      // Track the current column so the last one is kept when the loop runs to the end.
+      prob_idx = j;
+      cum_sum_prob += pprob[i * vocab_size + j];
+      if (cum_sum_prob > psample[i]) {
+        break;
+      }
+    }
+    parray[i] = prob_idx;
+  }
+  return new_array;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform").set_body_typed(MultinomialFromUniform);
+
// This is an inplace operation.
void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty)
{
ICHECK(logits.IsContiguous());
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index 650d8ace30..3457989a55 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring, invalid-name
+import numpy as np
import tvm
import tvm.testing
from tvm import relax, tir
@@ -61,11 +62,18 @@ def test_binary():
z4 = op.maximum(x, y)
z5 = op.minimum(x, y)
z6 = op.subtract(x, y)
- return (z0, z1, z2, z3, z4, z5, z6)
+ z7 = op.greater(x, y)
+ z8 = op.greater_equal(x, y)
+ z9 = op.less(x, y)
+ z10 = op.less_equal(x, y)
+ z11 = op.equal(x, y)
+ z12 = op.not_equal(x, y)
+
+ return (z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12)
# fmt: off
@R.function
- def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")), R.Tuple(R.Object)):
+ def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object):
R.func_attr({"num_input": 3})
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
@@ -75,7 +83,13 @@ def test_binary():
maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y)
- gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")),
R.Tuple(R.Object)) = (add, mul, divide, matmul, maximum, minimum, subtract),
(_io,)
+ greater: R.Tensor((10, 10), dtype="bool") = x > y
+ greater_equal: R.Tensor((10, 10), dtype="bool") = x >= y
+ less: R.Tensor((10, 10), dtype="bool") = x < y
+ less_equal: R.Tensor((10, 10), dtype="bool") = x <= y
+ equal: R.Tensor((10, 10), dtype="bool") = R.equal(x, y)
+ not_equal: R.Tensor((10, 10), dtype="bool") = R.not_equal(x, y)
+ gv1 = (add, mul, divide, matmul, maximum, minimum, subtract,
greater, greater_equal, less, less_equal, equal, not_equal), (_io,)
R.output(gv1)
return gv1
# fmt: on
@@ -829,5 +843,350 @@ def test_empty():
vm["test"](*effects)
+@tvm.testing.requires_gpu
+def test_multinomial_from_uniform():
+
+ prob_shape = (4, 5)
+ sample_shape = (4, 1)
+
+ class Model(Module):
+ def foo(self, prob: Tensor, uniform_sample: Tensor):
+ z0 = op.multinomial_from_uniform(prob, uniform_sample)
+ return z0
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ prob = T.match_buffer(A, (batch, vocab_size))
+ usample = T.match_buffer(B, (batch, 1))
+ output_index = T.match_buffer(C, (batch, 1), "int64")
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_sample_index"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 -
T.int64(1):v_ax1 - T.int64(1) + T.int64(2)])
+ T.writes(output_index[v_ax0, 0])
+ if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or
v_ax1 + T.int64(1) == vocab_size:
+ if v_ax1 == T.int64(0):
+ output_index[v_ax0, 0] = T.int64(0)
+ else:
+ if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1
- T.int64(1)]:
+ output_index[v_ax0, 0] = v_ax1
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample:
R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1),
dtype="int64"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 3})
+ cls = Expected
+ with R.dataflow():
+ cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob,
axis=1, dtype="void", exclusive=False)
+ lv1 = R.call_tir(cls.get_sample_index, (cumsum,
uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64"))
+ gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"),
R.Tuple(R.Object)) = lv1, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ mod, _ = m.export_tvm(
+ spec={
+ "foo": {
+ "prob": spec.Tensor(prob_shape, "float32"),
+ "uniform_sample": spec.Tensor(sample_shape, "float32"),
+ }
+ },
+ debug=True,
+ )
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+ target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+ with target:
+ mod = tir.transform.DefaultGPUSchedule()(mod)
+ ex = relax.build(mod, target)
+ dev = tvm.cuda(0)
+ vm = relax.VirtualMachine(ex, dev)
+
+ effects = vm["_initialize_effect"]()
+
+ np_rand = np.random.rand(*prob_shape).astype(np.float32)
+ # normalize it to get the random prob
+ np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
+ nd_prob = tvm.nd.array(np_prob, dev)
+ # special sample to get deterministic results
+ nd_sample = tvm.nd.array(np.array([[1], [0], [0],
[1]]).astype(np.float32), dev)
+ inputs = [nd_prob, nd_sample, effects]
+ res = vm["foo"](*inputs)
+ tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0],
[4]]).astype(np.int64))
+
+
+@tvm.testing.requires_gpu
+def test_sample_top_p_top_k_from_sorted_prob():
+ prob_shape = (2, 3)
+ sample_shape = (2, 1)
+
+ class Model(Module):
+ def foo(
+ self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor,
uniform_sample: Tensor
+ ):
+ z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p,
top_k, uniform_sample)
+ return z0
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D:
T.handle, E: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
+ renorm_prob = T.match_buffer(B, (batch, 1))
+ usample = T.match_buffer(C, (batch, 1))
+ indices = T.match_buffer(D, (batch, vocab_size), "int64")
+ output_index = T.match_buffer(E, (batch, 1), "int64")
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_index_from_sorted"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(
+ usample[v_ax0, T.int64(0)],
+ cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 -
T.int64(1) + T.int64(2)],
+ renorm_prob[v_ax0, 0],
+ indices[
+ v_ax0,
+ T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1)
+ + (T.max(T.int64(0), v_ax1) + T.int64(1) -
T.min(T.int64(0), v_ax1)),
+ ],
+ )
+ T.writes(output_index[v_ax0, 0])
+ if (
+ usample[v_ax0, T.int64(0)]
+ < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+ or v_ax1 + T.int64(1) == vocab_size
+ ):
+ if v_ax1 == T.int64(0):
+ output_index[v_ax0, 0] = indices[v_ax0, 0]
+ else:
+ if (
+ usample[v_ax0, T.int64(0)]
+ >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] /
renorm_prob[v_ax0, 0]
+ ):
+ output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
+
+ @T.prim_func(private=True)
+ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D:
T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
+ top_p = T.match_buffer(B, (batch, 1))
+ top_k = T.match_buffer(C, (batch, 1), "int64")
+ renorm_prob = T.match_buffer(D, (batch, 1))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_renorm_prob"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0),
v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1))
+ (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) -
T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0],
top_k[v_ax0, 0])
+ T.writes(renorm_prob[v_ax0, 0])
+ if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and
top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
+ renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
+ else:
+ if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and
v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
+ if v_ax1 + T.int64(1) == vocab_size:
+ renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0,
v_ax1]
+ else:
+ if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] <
top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) ==
T.bool(False):
+ renorm_prob[v_ax0, 0] =
cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def foo(
+ prob: R.Tensor((2, 3), dtype="float32"),
+ index: R.Tensor((2, 3), dtype="int64"),
+ top_p: R.Tensor((2, 1), dtype="float32"),
+ top_k: R.Tensor((2, 1), dtype="int64"),
+ uniform_sample: R.Tensor((2, 1), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 6})
+ cls = Expected
+ with R.dataflow():
+ cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob,
axis=1, dtype="void", exclusive=None)
+ lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k),
out_sinfo=R.Tensor((2, 1), dtype="float32"))
+ lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1,
uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64"))
+ gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"),
R.Tuple(R.Object)) = lv2, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ mod, _ = m.export_tvm(
+ spec={
+ "foo": {
+ "prob": spec.Tensor(prob_shape, "float32"),
+ "index": spec.Tensor(prob_shape, "int64"),
+ "top_p": spec.Tensor(sample_shape, "float32"),
+ "top_k": spec.Tensor(sample_shape, "int64"),
+ "uniform_sample": spec.Tensor(sample_shape, "float32"),
+ }
+ },
+ debug=True,
+ )
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+ target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+ with target:
+ mod = tir.transform.DefaultGPUSchedule()(mod)
+
+ ex = relax.build(mod, target)
+ dev = tvm.cuda(0)
+ vm = relax.VirtualMachine(ex, dev)
+
+ effects = vm["_initialize_effect"]()
+ sorted_prob = tvm.nd.array(np.array([[0.5, 0.4, 0.1], [0.4, 0.3,
0.3]]).astype(np.float32), dev)
+ indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64),
dev)
+ top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
+ top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
+ usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev)
+
+ inputs = [sorted_prob, indices, top_p, top_k, usample, effects]
+
+ res = vm["foo"](*inputs)
+ tvm.testing.assert_allclose(res[0].numpy(), np.array([[2],
[0]]).astype(np.int64))
+
+
+@tvm.testing.requires_gpu
+def test_renormalize_top_p_top_k_prob():
+ prob_shape = (2, 3)
+ sample_shape = (2, 1)
+
+ class Model(Module):
+ def foo(
+ self,
+ prob: Tensor,
+ sorted_prob: Tensor,
+ top_p: Tensor,
+ top_k: Tensor,
+ ):
+ z0 = op.renormalize_top_p_top_k_prob(prob, sorted_prob, top_p,
top_k)
+ return z0
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)),
"float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"),
filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i, j in T.grid(T.int64(2), T.int64(3)):
+ with T.block("filter_with_top_p_top_k"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(B[v_i, T.int64(0)], A[v_i, v_j])
+ T.writes(filter_with_top_p_top_k[v_i, v_j])
+ filter_with_top_p_top_k[v_i, v_j] = T.Select(B[v_i,
T.int64(0)] <= A[v_i, v_j], A[v_i, v_j], T.float32(0))
+
+ @T.prim_func(private=True)
+ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D:
T.handle, E: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ sorted_prob = T.match_buffer(A, (batch, vocab_size))
+ cumsum_sorted = T.match_buffer(B, (batch, vocab_size))
+ top_p = T.match_buffer(C, (batch, 1))
+ top_k = T.match_buffer(D, (batch, 1), "int64")
+ cutoff = T.match_buffer(E, (batch, 1))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_renorm_prob"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0),
v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1))
+ (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) -
T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0],
top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 +
T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) +
(T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + [...]
+ T.writes(cutoff[v_ax0, 0])
+ if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and
top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
+ cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
+ else:
+ if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and
v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
+ if v_ax1 + T.int64(1) == vocab_size:
+ cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
+ else:
+ if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] <
top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) ==
T.bool(False):
+ cutoff[v_ax0, 0] = sorted_prob[v_ax0,
v_ax1 + T.int64(1)]
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob:
R.Tensor((2, 3), dtype="float32"), top_p: R.Tensor((2, 1), dtype="float32"),
top_k: R.Tensor((2, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((2,
3), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 5})
+ cls = Expected
+ with R.dataflow():
+ cumsum: R.Tensor((2, 3), dtype="float32") =
R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None)
+ lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum,
top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32"))
+ lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1),
out_sinfo=R.Tensor((2, 3), dtype="float32"))
+ sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1],
keepdims=True)
+ divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum)
+ gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"),
R.Tuple(R.Object)) = divide, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ mod, _ = m.export_tvm(
+ spec={
+ "foo": {
+ "prob": spec.Tensor(prob_shape, "float32"),
+ "sorted_prob": spec.Tensor(prob_shape, "float32"),
+ "top_p": spec.Tensor(sample_shape, "float32"),
+ "top_k": spec.Tensor(sample_shape, "int64"),
+ }
+ },
+ debug=True,
+ )
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+ target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+ with target:
+ mod = relax.backend.DispatchSortScan()(mod)
+ mod = relax.transform.LegalizeOps()(mod)
+ mod = tir.transform.DefaultGPUSchedule()(mod)
+
+ ex = relax.build(mod, target)
+ dev = tvm.cuda(0)
+ vm = relax.VirtualMachine(ex, dev)
+
+ effects = vm["_initialize_effect"]()
+ prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.3,
0.4]]).astype(np.float32), dev)
+ sorted_prob = tvm.nd.array(np.array([[0.5, 0.3, 0.2], [0.4, 0.3,
0.3]]).astype(np.float32), dev)
+ top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
+ top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
+
+ inputs = [prob, sorted_prob, top_p, top_k, effects]
+
+ res = vm["foo"](*inputs)
+ tvm.testing.assert_allclose(
+ res[0].numpy(), np.array([[0, 0.375, 0.625], [0.3, 0.3,
0.4]]).astype(np.float32)
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_vm_builtin.py
b/tests/python/relax/test_vm_builtin.py
new file mode 100644
index 0000000000..f786f707af
--- /dev/null
+++ b/tests/python/relax/test_vm_builtin.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+
+
+def test_multinomial_from_uniform():
+    """Run ``vm.builtin.multinomial_from_uniform`` on CPU through the relax VM."""
+
+    @tvm.script.ir_module
+    class CallSample:
+        @R.function
+        def foo(x: R.Tensor((3, 5), "float32"), y: R.Tensor((3, 1), "float32")):
+            z = R.call_pure_packed(
+                "vm.builtin.multinomial_from_uniform",
+                x,
+                y,
+                sinfo_args=(R.Tensor((3, 1), dtype="int64")),
+            )
+            return z
+
+    mod = CallSample
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target)
+    np_rand = np.random.rand(3, 5).astype(np.float32)
+    # Normalize each row so that it sums to 1 and forms a probability distribution.
+    np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
+    nd_prob = tvm.nd.array(np_prob)
+    # Special samples to get deterministic results: 0 is expected to select index 0 and
+    # 1 the last index (4), presumably independent of the random probabilities above.
+    nd_sample = tvm.nd.array(np.array([[1.0], [0], [1]]).astype(np.float32))
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    res = vm["foo"](nd_prob, nd_sample)
+    tvm.testing.assert_allclose(res.numpy(), np.array([[4], [0], [4]]).astype(np.int64))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()