This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 84b3f69edb [Unity][SLM] GPU sampling (#16575)
84b3f69edb is described below

commit 84b3f69edba618d258d51cddd92618538c28ffb4
Author: Yong Wu <yongc...@gmail.com>
AuthorDate: Fri Feb 23 06:07:28 2024 -0800

    [Unity][SLM] GPU sampling (#16575)
    
    This PR adds GPU sampling support to SLM
---
 python/tvm/relax/frontend/nn/_tensor_op.py |  16 +
 python/tvm/relax/frontend/nn/op.py         | 501 +++++++++++++++++++++++++++++
 src/runtime/relax_vm/lm_support.cc         |  37 +++
 tests/python/relax/test_frontend_nn_op.py  | 365 ++++++++++++++++++++-
 tests/python/relax/test_vm_builtin.py      |  57 ++++
 5 files changed, 973 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py
index 3a646e29b8..7f44ca2438 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -67,6 +67,22 @@ class _TensorOp:
         other = _convert_scalar(other, self)
         return _op().divide(self, other)
 
+    def __lt__(self, other):  # element-wise `self < other`, dispatched to op.less
+        other = _convert_scalar(other, self)
+        return _op().less(self, other)
+
+    def __le__(self, other):  # element-wise `self <= other`, dispatched to op.less_equal
+        other = _convert_scalar(other, self)
+        return _op().less_equal(self, other)
+
+    def __gt__(self, other):  # element-wise `self > other`, dispatched to op.greater
+        other = _convert_scalar(other, self)
+        return _op().greater(self, other)
+
+    def __ge__(self, other):  # element-wise `self >= other`, dispatched to op.greater_equal
+        other = _convert_scalar(other, self)
+        return _op().greater_equal(self, other)
+
     def astype(self, dtype):
         return _op().astype(self, dtype)
 
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index b6c34ca265..6944fc8535 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -24,6 +24,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar
 import numpy as np
 
 from tvm import tir as _tir
+from tvm.script import tir as T
+from tvm import te
 
 from ... import expr as rx
 from ... import op as _op
@@ -1825,3 +1827,502 @@ def print_(tensor: Tensor):
     filename, line_number = 
inspect.getframeinfo(inspect.currentframe().f_back)[:2]
     line_info = f"{filename}:{line_number}"
     debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Broadcasted element-wise comparison for (a < b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a < b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (a <= b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a <= b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Broadcasted element-wise comparison for (a > b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a > b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (a >= b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a >= b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Broadcasted element-wise comparison for (a == b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a == b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (a != b).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        A boolean tensor holding `a != b`, broadcast over the input shapes.
+    """
+    return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
+    """Select elements from either `x1` or `x2` depending on the value of the
+    condition.
+
+    For a given position, return the corresponding value in `x1` if `condition` is True,
+    and return the corresponding value in `x2` otherwise.
+
+    Parameters
+    ----------
+    condition : Tensor
+        When True, yield `x1`; otherwise, yield `x2`.
+        Must be broadcasting compatible with `x1` and `x2`.
+        Must have boolean dtype.
+
+    x1 : Tensor
+        The first input tensor.
+        Must be broadcasting compatible with `condition` and `x2`.
+
+    x2 : Tensor
+        The second input tensor.
+        Must be broadcasting compatible with `condition` and `x1`.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result tensor.
+    """
+    return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """Numpy style cumsum op. Return the cumulative sum of the elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[str]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : Optional[bool]
+        If True, return the exclusive sum, in which the first element is not
+        included. Otherwise (False or None) the sum is inclusive.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64") -> Tensor:
+    """Returns a tensor where each row contains the index sampled from the multinomial
+    probability distribution located in the corresponding row of tensor prob.
+
+    Notes
+    -----
+    For better cpu performance, use 'vm.builtin.multinomial_from_uniform'.
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    prob : Tensor
+        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+        Each row is a distribution across vocabulary for a batch, where:
+        Values range from [0, 1], indicating the probability of each vocabulary item.
+        The sum of values in each row is 1, forming a valid distribution.
+
+    uniform_sample : Tensor
+        The uniformly sampled 2-D tensor with the shape (batch, 1).
+        Values range from 0 to 1, indicating probabilities sampled uniformly.
+
+    Returns
+    -------
+    result : Tensor
+        The sampled indices with shape (batch, 1) and element type `dtype`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
+        usample = [[0.4], [0.9]]
+
+        multinomial_from_uniform(prob, usample)
+        -> [[1], [2]]
+    """
+    prob_dtype = prob.dtype
+    sample_dtype = uniform_sample.dtype
+    batch = prob.shape[0]
+
+    @T.prim_func(private=True)
+    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)  # holds cumsum_prob (see args)
+        usample = T.match_buffer(B, (batch, 1), sample_dtype)
+        output_index = T.match_buffer(C, (batch, 1), dtype)
+
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_sample_index"):  # writes j with cumsum[j-1] <= u < cumsum[j]
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.writes(output_index[v_ax0, 0])
+                if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+                    if v_ax1 == 0:
+                        output_index[v_ax0, 0] = 0
+                    elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+                        output_index[v_ax0, 0] = v_ax1
+
+    cumsum_prob = cumsum(prob, axis=1, exclusive=False)
+
+    return tensor_ir_op(
+        _get_sample_index,
+        "get_sample_index",
+        args=[cumsum_prob, uniform_sample],
+        out=Tensor.placeholder([batch, 1], dtype),
+    )
+
+
+def sample_top_p_top_k_from_sorted_prob(
+    sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+) -> Tensor:
+    """Samples indices from a sorted probability tensor based on top_p and top_k criteria.
+
+    Notes
+    -----
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    sorted_prob : Tensor
+        A 2-D tensor, with shape (batch, vocab_size), contains probabilities
+        sorted in descending order.
+
+    sorted_index : Tensor
+        The indices tensor with shape (batch, vocab_size), corresponding to the
+        sorted_prob. Potentially from applying argsort on the original probability
+        tensor in descending order.
+
+    top_p : Tensor
+        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+    top_k : Tensor
+        A tensor with shape (batch, 1), representing the number of top probabilities
+        to consider for top-k sampling.
+
+    uniform_sample : Tensor
+        Uniform samples with shape (batch, 1) used to select the output indices.
+
+    Returns
+    -------
+    result : Tensor
+        The selected indices (taken from `sorted_index`) with shape (batch, 1).
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.1, 0.4, 0.5],
+                [0.3, 0.3, 0.4]]
+        sorted_prob = [[0.5, 0.4, 0.1],
+                       [0.4, 0.3, 0.3]]
+        sorted_index = [[2, 1, 0],
+                        [2, 0, 1]]
+        top_p = [[0.6], [0.9]]
+        top_k = [[3], [2]]
+        uniform_sample = [[0.5], [0.6]]
+
+        sample_top_p_top_k_from_sorted_prob(
+            sorted_prob, sorted_index, top_p, top_k, uniform_sample)
+        -> [[2], [0]]
+
+    """
+    prob_dtype, sample_dtype = sorted_prob.dtype, uniform_sample.dtype
+    index_dtype, top_k_dtype = sorted_index.dtype, top_k.dtype
+    batch = sorted_prob.shape[0]
+
+    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):  # True if position j+1 is also kept
+        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
+
+    @T.prim_func(private=True)
+    def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        top_p = T.match_buffer(B, (batch, 1), prob_dtype)
+        top_k = T.match_buffer(C, (batch, 1), top_k_dtype)
+        renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype)  # kept probability mass per row
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_renorm_prob"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
+                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
+                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
+                    if v_ax1 + 1 == vocab_size:
+                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
+                    elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
+                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]
+
+    @T.prim_func(private=True)
+    def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
+        usample = T.match_buffer(C, (batch, 1), sample_dtype)
+        indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
+        output_index = T.match_buffer(E, (batch, 1), index_dtype)
+
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_index_from_sorted"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.writes(output_index[v_ax0, 0])
+                if (
+                    usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+                    or v_ax1 + 1 == vocab_size
+                ):
+                    if v_ax1 == 0:
+                        output_index[v_ax0, 0] = indices[v_ax0, 0]
+                    elif (
+                        usample[v_ax0, T.int64(0)]
+                        >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
+                    ):
+                        output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
+
+    cumsum_sorted = cumsum(sorted_prob, axis=1)
+
+    renorm_prob = tensor_ir_op(
+        _get_renorm_prob,
+        "get_renorm_prob",
+        args=[cumsum_sorted, top_p, top_k],
+        out=Tensor.placeholder(
+            [batch, 1],
+            prob_dtype,
+        ),
+    )
+
+    out_index_in_sorted = tensor_ir_op(
+        _get_index_from_sorted,
+        "get_index_from_sorted",
+        args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
+        out=Tensor.placeholder([batch, 1], index_dtype),
+    )
+    return out_index_in_sorted
+
+
+def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k) -> Tensor:
+    """Renormalizes probabilities after filtering with top_p and top_k, ensuring
+    they sum up to 1.
+
+    Notes
+    -----
+    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+    Parameters
+    ----------
+    prob : Tensor
+        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+
+    sorted_prob : Tensor
+        The values of `prob` sorted in descending order within each row.
+
+    top_p : Tensor
+        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+    top_k : Tensor
+        A tensor with shape (batch, 1), representing the number of top probabilities
+        to consider for top-k sampling.
+
+    Returns
+    -------
+    result : Tensor
+        The filtered and normalized tensor with the same shape and dtype as input prob.
+    """
+    prob_dtype = prob.dtype
+    top_k_dtype = top_k.dtype
+    batch = sorted_prob.shape[0]
+
+    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):  # True if position j+1 is also kept
+        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
+
+    @T.prim_func(private=True)
+    def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        sorted_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+        cumsum_sorted = T.match_buffer(B, (batch, vocab_size), prob_dtype)
+        top_p = T.match_buffer(C, (batch, 1), prob_dtype)
+        top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
+        cutoff = T.match_buffer(E, (batch, 1), prob_dtype)  # smallest sorted prob kept per row
+        for ax0, ax1 in T.grid(batch, vocab_size):
+            with T.block("T_get_renorm_prob"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
+                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
+                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
+                    if v_ax1 + 1 == vocab_size:
+                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
+                    elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
+                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]
+
+    cumsum_sorted = cumsum(sorted_prob, axis=1)
+
+    renorm_cutoff = tensor_ir_op(
+        _get_renorm_cutoff,
+        "get_renorm_cutoff",
+        args=[sorted_prob, cumsum_sorted, top_p, top_k],
+        out=Tensor.placeholder(
+            [batch, 1],
+            prob_dtype,
+        ),
+    )
+
+    filtered_prob = tensor_expr_op(  # typed zero keeps Select valid for non-fp32 prob
+        lambda p, cutoff: te.compute(  # keep p[i, j] only if it reaches the row's cutoff
+            p.shape,
+            lambda i, j: _tir.Select(p[i, j] >= cutoff[i, 0], p[i, j], _tir.const(0, p.dtype)),
+            name="filter_with_top_p_top_k",
+        ),
+        "filter_with_top_p_top_k",
+        args=[prob, renorm_cutoff],
+    )
+    renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True)
+    return renorm_prob
diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index cfb78006d7..95dca0c6d5 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -496,6 +496,43 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
 
 
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
 
+NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) {
+  ICHECK(prob.IsContiguous());
+  ICHECK(uniform_sample.IsContiguous());
+
+  if (prob->device.device_type != kDLCPU) {
+    prob = prob.CopyTo(DLDevice{kDLCPU, 0});
+  }
+  if (uniform_sample->device.device_type != kDLCPU) {
+    uniform_sample = uniform_sample.CopyTo(DLDevice{kDLCPU, 0});
+  }
+
+  ICHECK(DataType(prob->dtype) == DataType::Float(32)) << "prob must be float32";
+  ICHECK(DataType(uniform_sample->dtype) == DataType::Float(32)) << "uniform_sample must be float32";
+
+  int64_t batch_size = prob->shape[0];
+  int64_t vocab_size = prob->shape[prob->ndim - 1];
+  const float* pprob = static_cast<float*>(prob->data);  // float32 verified above
+  const float* psample = static_cast<float*>(uniform_sample->data);
+  NDArray new_array = NDArray::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device);
+  int64_t* parray = static_cast<int64_t*>(new_array->data);
+  for (int64_t i = 0; i < batch_size; ++i) {
+    float cum_sum_prob = 0.0f;
+    int64_t prob_idx = 0;
+    for (int64_t j = 0; j < vocab_size; ++j) {
+      prob_idx = j;
+      cum_sum_prob += pprob[i * vocab_size + j];
+      if (cum_sum_prob > psample[i]) {
+        break;
+      }
+    }
+    parray[i] = prob_idx;
+  }
+  return new_array;
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform").set_body_typed(MultinomialFromUniform);
+
 // This is an inplace operation.
 void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) 
{
   ICHECK(logits.IsContiguous());
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 650d8ace30..3457989a55 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring, invalid-name
+import numpy as np
 import tvm
 import tvm.testing
 from tvm import relax, tir
@@ -61,11 +62,18 @@ def test_binary():
             z4 = op.maximum(x, y)
             z5 = op.minimum(x, y)
             z6 = op.subtract(x, y)
-            return (z0, z1, z2, z3, z4, z5, z6)
+            z7 = op.greater(x, y)
+            z8 = op.greater_equal(x, y)
+            z9 = op.less(x, y)
+            z10 = op.less_equal(x, y)
+            z11 = op.equal(x, y)
+            z12 = op.not_equal(x, y)
+
+            return (z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12)
 
     # fmt: off
     @R.function
-    def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), 
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")), R.Tuple(R.Object)):
+    def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), 
dtype="float32"), _io: R.Object):
         R.func_attr({"num_input": 3})
         with R.dataflow():
             add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
@@ -75,7 +83,13 @@ def test_binary():
             maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
             minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
             subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y)
-            gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")), 
R.Tuple(R.Object)) = (add, mul, divide, matmul, maximum, minimum, subtract), 
(_io,)
+            greater: R.Tensor((10, 10), dtype="bool") = x > y
+            greater_equal: R.Tensor((10, 10), dtype="bool") = x >= y
+            less: R.Tensor((10, 10), dtype="bool") = x < y
+            less_equal: R.Tensor((10, 10), dtype="bool") = x <= y
+            equal: R.Tensor((10, 10), dtype="bool") = R.equal(x, y)
+            not_equal: R.Tensor((10, 10), dtype="bool") = R.not_equal(x, y)
+            gv1 = (add, mul, divide, matmul, maximum, minimum, subtract, 
greater, greater_equal, less, less_equal, equal, not_equal), (_io,)
             R.output(gv1)
         return gv1
     # fmt: on
@@ -829,5 +843,350 @@ def test_empty():
     vm["test"](*effects)
 
 
+@tvm.testing.requires_cuda
+def test_multinomial_from_uniform():
+
+    prob_shape = (4, 5)
+    sample_shape = (4, 1)
+
+    class Model(Module):
+        def foo(self, prob: Tensor, uniform_sample: Tensor):
+            z0 = op.multinomial_from_uniform(prob, uniform_sample)
+            return z0
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+            batch, vocab_size = T.int64(), T.int64()
+            prob = T.match_buffer(A, (batch, vocab_size))
+            usample = T.match_buffer(B, (batch, 1))
+            output_index = T.match_buffer(C, (batch, 1), "int64")
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(batch, vocab_size):
+                with T.block("T_get_sample_index"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)])
+                    T.writes(output_index[v_ax0, 0])
+                    if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + T.int64(1) == vocab_size:
+                        if v_ax1 == T.int64(0):
+                            output_index[v_ax0, 0] = T.int64(0)
+                        else:
+                            if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - T.int64(1)]:
+                                output_index[v_ax0, 0] = v_ax1
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 3})
+            cls = Expected
+            with R.dataflow():
+                cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=False)
+                lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64"))
+                gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,)
+                R.output(gv1)
+            return gv1
+    # fmt: on
+
+    m = Model()
+    mod, _ = m.export_tvm(
+        spec={
+            "foo": {
+                "prob": spec.Tensor(prob_shape, "float32"),
+                "uniform_sample": spec.Tensor(sample_shape, "float32"),
+            }
+        },
+        debug=True,
+    )
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+    target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+    with target:
+        mod = tir.transform.DefaultGPUSchedule()(mod)
+    ex = relax.build(mod, target)
+    dev = tvm.cuda(0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    effects = vm["_initialize_effect"]()
+
+    np_rand = np.random.rand(*prob_shape).astype(np.float32)
+    # normalize it to get the random prob
+    np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
+    nd_prob = tvm.nd.array(np_prob, dev)
+    # special sample to get deterministic results
+    nd_sample = tvm.nd.array(np.array([[1], [0], [0], [1]]).astype(np.float32), dev)
+    inputs = [nd_prob, nd_sample, effects]
+    res = vm["foo"](*inputs)
+    tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0], [4]]).astype(np.int64))
+
+
+@tvm.testing.requires_gpu
+def test_sample_top_p_top_k_from_sorted_prob():
+    prob_shape = (2, 3)
+    sample_shape = (2, 1)
+
+    class Model(Module):
+        def foo(
+            self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor, 
uniform_sample: Tensor
+        ):
+            z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p, 
top_k, uniform_sample)
+            return z0
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle, E: T.handle):
+            batch, vocab_size = T.int64(), T.int64()
+            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
+            renorm_prob = T.match_buffer(B, (batch, 1))
+            usample = T.match_buffer(C, (batch, 1))
+            indices = T.match_buffer(D, (batch, vocab_size), "int64")
+            output_index = T.match_buffer(E, (batch, 1), "int64")
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(batch, vocab_size):
+                with T.block("T_get_index_from_sorted"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(
+                        usample[v_ax0, T.int64(0)],
+                        cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 - 
T.int64(1) + T.int64(2)],
+                        renorm_prob[v_ax0, 0],
+                        indices[
+                            v_ax0,
+                            T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1)
+                            + (T.max(T.int64(0), v_ax1) + T.int64(1) - 
T.min(T.int64(0), v_ax1)),
+                        ],
+                    )
+                    T.writes(output_index[v_ax0, 0])
+                    if (
+                        usample[v_ax0, T.int64(0)]
+                        < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+                        or v_ax1 + T.int64(1) == vocab_size
+                    ):
+                        if v_ax1 == T.int64(0):
+                            output_index[v_ax0, 0] = indices[v_ax0, 0]
+                        else:
+                            if (
+                                usample[v_ax0, T.int64(0)]
+                                >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] / 
renorm_prob[v_ax0, 0]
+                            ):
+                                output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
+
+        @T.prim_func(private=True)
+        def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle):
+            batch, vocab_size = T.int64(), T.int64()
+            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
+            top_p = T.match_buffer(B, (batch, 1))
+            top_k = T.match_buffer(C, (batch, 1), "int64")
+            renorm_prob = T.match_buffer(D, (batch, 1))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(batch, vocab_size):
+                with T.block("T_get_renorm_prob"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), 
v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) 
+ (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - 
T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], 
top_k[v_ax0, 0])
+                    T.writes(renorm_prob[v_ax0, 0])
+                    if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and 
top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
+                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
+                    else:
+                        if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and 
v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
+                            if v_ax1 + T.int64(1) == vocab_size:
+                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 
v_ax1]
+                            else:
+                                if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < 
top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == 
T.bool(False):
+                                    renorm_prob[v_ax0, 0] = 
cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def foo(
+            prob: R.Tensor((2, 3), dtype="float32"),
+            index: R.Tensor((2, 3), dtype="int64"),
+            top_p: R.Tensor((2, 1), dtype="float32"),
+            top_k: R.Tensor((2, 1), dtype="int64"),
+            uniform_sample: R.Tensor((2, 1), dtype="float32"),
+            _io: R.Object,
+        ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 6})
+            cls = Expected
+            with R.dataflow():
+                cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, 
axis=1, dtype="void", exclusive=None)
+                lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), 
out_sinfo=R.Tensor((2, 1), dtype="float32"))
+                lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1, 
uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64"))
+                gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"), 
R.Tuple(R.Object)) = lv2, (_io,)
+                R.output(gv1)
+            return gv1
+    # fmt: on
+
+    m = Model()
+    mod, _ = m.export_tvm(
+        spec={
+            "foo": {
+                "prob": spec.Tensor(prob_shape, "float32"),
+                "index": spec.Tensor(prob_shape, "int64"),
+                "top_p": spec.Tensor(sample_shape, "float32"),
+                "top_k": spec.Tensor(sample_shape, "int64"),
+                "uniform_sample": spec.Tensor(sample_shape, "float32"),
+            }
+        },
+        debug=True,
+    )
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+    target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+    with target:
+        mod = tir.transform.DefaultGPUSchedule()(mod)
+
+    ex = relax.build(mod, target)
+    dev = tvm.cuda(0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    effects = vm["_initialize_effect"]()
+    sorted_prob = tvm.nd.array(np.array([[0.5, 0.4, 0.1], [0.4, 0.3, 
0.3]]).astype(np.float32), dev)
+    indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), 
dev)
+    top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
+    top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
+    usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev)
+
+    inputs = [sorted_prob, indices, top_p, top_k, usample, effects]
+
+    res = vm["foo"](*inputs)
+    tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], 
[0]]).astype(np.int64))
+
+
+@tvm.testing.requires_gpu
+def test_renormalize_top_p_top_k_prob():
+    prob_shape = (2, 3)
+    sample_shape = (2, 1)
+
+    class Model(Module):
+        def foo(
+            self,
+            prob: Tensor,
+            sorted_prob: Tensor,
+            top_p: Tensor,
+            top_k: Tensor,
+        ):
+            z0 = op.renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, 
top_k)
+            return z0
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), 
"float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), 
filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(2), T.int64(3)):
+                with T.block("filter_with_top_p_top_k"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(B[v_i, T.int64(0)], A[v_i, v_j])
+                    T.writes(filter_with_top_p_top_k[v_i, v_j])
+                    filter_with_top_p_top_k[v_i, v_j] = T.Select(B[v_i, 
T.int64(0)] <= A[v_i, v_j], A[v_i, v_j], T.float32(0))
+
+        @T.prim_func(private=True)
+        def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: 
T.handle, E: T.handle):
+            batch, vocab_size = T.int64(), T.int64()
+            sorted_prob = T.match_buffer(A, (batch, vocab_size))
+            cumsum_sorted = T.match_buffer(B, (batch, vocab_size))
+            top_p = T.match_buffer(C, (batch, 1))
+            top_k = T.match_buffer(D, (batch, 1), "int64")
+            cutoff = T.match_buffer(E, (batch, 1))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(batch, vocab_size):
+                with T.block("T_get_renorm_prob"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), 
v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) 
+ (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - 
T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], 
top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + 
T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + 
(T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) +  [...]
+                    T.writes(cutoff[v_ax0, 0])
+                    if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and 
top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
+                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
+                    else:
+                        if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and 
v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
+                            if v_ax1 + T.int64(1) == vocab_size:
+                                cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
+                            else:
+                                if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < 
top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == 
T.bool(False):
+                                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, 
v_ax1 + T.int64(1)]
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: 
R.Tensor((2, 3), dtype="float32"), top_p: R.Tensor((2, 1), dtype="float32"), 
top_k: R.Tensor((2, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((2, 
3), dtype="float32"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 5})
+            cls = Expected
+            with R.dataflow():
+                cumsum: R.Tensor((2, 3), dtype="float32") = 
R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None)
+                lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, 
top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32"))
+                lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), 
out_sinfo=R.Tensor((2, 3), dtype="float32"))
+                sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1], 
keepdims=True)
+                divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum)
+                gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"), 
R.Tuple(R.Object)) = divide, (_io,)
+                R.output(gv1)
+            return gv1
+    # fmt: on
+
+    m = Model()
+    mod, _ = m.export_tvm(
+        spec={
+            "foo": {
+                "prob": spec.Tensor(prob_shape, "float32"),
+                "sorted_prob": spec.Tensor(prob_shape, "float32"),
+                "top_p": spec.Tensor(sample_shape, "float32"),
+                "top_k": spec.Tensor(sample_shape, "int64"),
+            }
+        },
+        debug=True,
+    )
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+    target = tvm.target.Target("cuda -libs=thrust", host="llvm")
+    with target:
+        mod = relax.backend.DispatchSortScan()(mod)
+        mod = relax.transform.LegalizeOps()(mod)
+        mod = tir.transform.DefaultGPUSchedule()(mod)
+
+    ex = relax.build(mod, target)
+    dev = tvm.cuda(0)
+    vm = relax.VirtualMachine(ex, dev)
+
+    effects = vm["_initialize_effect"]()
+    prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 
0.4]]).astype(np.float32), dev)
+    sorted_prob = tvm.nd.array(np.array([[0.5, 0.3, 0.2], [0.4, 0.3, 
0.3]]).astype(np.float32), dev)
+    top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev)
+    top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev)
+
+    inputs = [prob, sorted_prob, top_p, top_k, effects]
+
+    res = vm["foo"](*inputs)
+    tvm.testing.assert_allclose(
+        res[0].numpy(), np.array([[0, 0.375, 0.625], [0.3, 0.3, 
0.4]]).astype(np.float32)
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_vm_builtin.py 
b/tests/python/relax/test_vm_builtin.py
new file mode 100644
index 0000000000..f786f707af
--- /dev/null
+++ b/tests/python/relax/test_vm_builtin.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+
+
+def test_multinomial_from_uniform():
+    @tvm.script.ir_module
+    class CallSample:
+        @R.function
+        def foo(x: R.Tensor((3, 5), "float32"), y: R.Tensor((3, 1), 
"float32")):
+            z = R.call_pure_packed(
+                "vm.builtin.multinomial_from_uniform",
+                x,
+                y,
+                sinfo_args=(R.Tensor((3, 1), dtype="int64")),
+            )
+            return z
+
+    mod = CallSample
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target)
+    np_rand = np.random.rand(3, 5).astype(np.float32)
+    # normalize it to get the random prob
+    np_prob = np_rand / np_rand.sum(axis=1, keepdims=True)
+    nd_prob = tvm.nd.array(np_prob)
+    # special sample to get deterministic results
+    nd_sample = tvm.nd.array(np.array([[1.0], [0], [1]]).astype(np.float32))
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    res = vm["foo"](nd_prob, nd_sample)
+    tvm.testing.assert_allclose(res.numpy(), np.array([[4], [0], 
[4]]).astype(np.int64))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to