tqchen commented on code in PR #16575:
URL: https://github.com/apache/tvm/pull/16575#discussion_r1500129673
##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1825,3 +1827,509 @@ def print_(tensor: Tensor):
filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
line_info = f"{filename}:{line_number}"
debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs < rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs <= rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs > rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs >= rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs == rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+ """Broadcasted element-wise comparison for (lhs != rhs).
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+ """
+ return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
+ """Selecting elements from either the input tensors depending on the value of the
+ condition.
+
+ For a given position, return the corresponding value in `x1` if `condition` is True,
+ and return the corresponding value in `x2` otherwise.
+
+ Parameters
+ ----------
+ condition : Tensor
+ When True, yield `x1`; otherwise, yield `x2`.
+ Must be broadcasting compatible with `x1` and `x2`.
+ Must have boolean dtype.
+
+ x1 : Tensor
+ The first input tensor.
+ Must be broadcasting compatible with `condition` and `x2`.
+
+ x2 : Tensor
+ The second input tensor.
+ Must be broadcasting compatible with `condition` and `x1`.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The result tensor.
+ """
+ return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)
+
+
+def cumsum(
+ data: Tensor,
+ axis: Optional[int] = None,
+ dtype: Optional[str] = None,
+ exclusive: Optional[bool] = None,
+ name: str = "cumsum",
+) -> Tensor:
+ """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
+ a given axis.
+
+ Parameters
+ ----------
+ data : Tensor
+ The input data to the operator.
+
+ axis : Optional[int]
+ Axis along which the cumulative sum is computed. The default (None) is to compute
+ the cumsum over the flattened array.
+
+ dtype : Optional[str]
+ Type of the returned array and of the accumulator in which the elements are summed.
+ If dtype is not specified, it defaults to the dtype of data.
+
+ exclusive : Optional[bool]
+ If true will return exclusive sum in which the first element is not
+ included.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The result has the same size as data, and the same shape as data if axis is not None.
+ If axis is None, the result is a 1-d array.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ a = [[1, 2, 3], [4, 5, 6]]
+
+ cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
+ -> [ 1, 3, 6, 10, 15, 21]
+
+ cumsum(a, dtype="float32")
+ -> [ 1., 3., 6., 10., 15., 21.]
+
+ cumsum(a, axis=0) # sum over rows for each of the 3 columns
+ -> [[1, 2, 3],
+ [5, 7, 9]]
+
+ cumsum(a, axis=1)
+ -> [[ 1, 3, 6],
+ [ 4, 9, 15]]
+
+ a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array
+ cumsum(a, dtype=int32) # dtype should be provided to get the expected results
+ -> [1, 1, 2, 2, 3, 4, 4]
+ """
+ return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
+ """Returns a tensor where each row contains the index sampled from the multinomial
+ probability distribution located in the corresponding row of tensor prob.
+
+ Notes
+ -----
+ For better cpu performance, use 'vm.builtin.multinomial_from_uniform'.
+ For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+ Parameters
+ ----------
+ prob : Tensor
+ A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+ Each row is a distribution across vocabulary for a batch, where:
+ Values range from [0, 1], indicating the probability of each vocabulary item.
+ The sum of values in each row is 1, forming a valid distribution.
+
+ uniform_sample : Tensor
+ The uniformly sampled 2-D tensor with the shape (batch, 1).
+ Values range from 0 to 1, indicating probabilities sampled uniformly.
+
+ Returns
+ -------
+ result : Tensor
+ The computed tensor with shape (batch, 1).
+
+ Examples
+ --------
+ .. code-block:: python
+
+ prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
+ usample = [[0.4], [0.9]]
+
+ multinomial_from_uniform(prob, usample)
+ -> [[1], [2]]
+ """
+ prob_dtype = prob.dtype
+ sample_dtype = uniform_sample.dtype
+ batch = prob.shape[0]
+ cumsum_prob = cumsum(prob, axis=1, exclusive=False)
+
+ @T.prim_func(private=True)
+ def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+ usample = T.match_buffer(B, (batch, 1), sample_dtype)
+ output_index = T.match_buffer(C, (batch, 1), dtype)
+
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_sample_index"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.writes(output_index[v_ax0, 0])
+ if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+ if v_ax1 == 0:
+ output_index[v_ax0, 0] = 0
+ elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+ output_index[v_ax0, 0] = v_ax1
+
+ return tensor_ir_op(
+ _get_sample_index,
+ "get_sample_index",
+ args=[cumsum_prob, uniform_sample],
+ out=Tensor.placeholder([batch, 1], dtype),
+ )
+
+
+def sample_top_p_top_k_from_sorted_prob(
+ sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+):
+ """Samples indices from a sorted probability tensor based on top_p and top_k criteria.
+
+ Notes
+ -----
+ For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+ Parameters
+ ----------
+ sorted_prob : Tensor
+ A 2-D tensor, with shape (batch, vocab_size), contains probabilities
+ sorted in descending order.
+
+ sorted_index: Tensor
+ The indices tensor with shape (batch, vocab_size), corresponding to the
+ sorted_prob. Potentially from applying argsort on the original probability
+ tensor in descending order.
+
+ top_p : Tensor
+ The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+ top_k :Tensor
+ A tensor with shape (batch, 1), representing the number of top probabilities
+ to consider for top-k sampling.
+
+ uniform_sample : Tensor
+ Uniformly sampled values with shape (batch, 1) are used to select the output indices.
+
+ Returns
+ -------
+ result : Tensor
+ The selected indices with shape (batch, 1).
+
+ Examples
+ --------
+ .. code-block:: python
+
+ prob = [[0.1 , 0.4, 0.5],
+ [0.3, 0.3, 0.4]]
+ sorted_prob = [[0.5, 0.4, 0.1],
+ [0.4, 0.3, 0.3]]
+ sorted_index = [[2, 1, 0],
+ [2, 0, 1]]
+ top_p = [[0.6],[0.9]]
+ top_k = [[3],[2]]
+ uniform_sample = [[0.5], [0.6]]
+
+ sample_top_p_top_k_from_sorted_prob(
+ sorted_prob, sorted_index,top_p, top_k, uniform_sample)
+ -> [2, 0]
+
+ """
+ prob_dtype = sorted_prob.dtype
+ index_dtype = sorted_index.dtype
+ batch = sorted_prob.shape[0]
+
+ @T.prim_func(private=True)
+ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ cumsum_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+ cumsum_mask = T.match_buffer(B, (batch, vocab_size), "bool")
+ renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_renorm_prob"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ if cumsum_mask[v_ax0, 0] == 0:
+ renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, 0]
+ elif cumsum_mask[v_ax0, v_ax1] == 1 and cumsum_mask[v_ax0, v_ax1 + 1] == 0:
+ renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1 + 1]
+ elif cumsum_mask[v_ax0, v_ax1] == 1 and v_ax1 + 1 == vocab_size:
+ renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1]
+
+ @T.prim_func(private=True)
+ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+ batch, vocab_size = T.int64(), T.int64()
+ cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
+ renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
+ usample = T.match_buffer(C, (batch, 1), prob_dtype)
+ indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
+ output_index = T.match_buffer(E, (batch, 1), index_dtype)
+
+ for ax0, ax1 in T.grid(batch, vocab_size):
+ with T.block("T_get_index_from_sorted"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.writes(output_index[v_ax0, 0])
+ if (
+ usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+ or v_ax1 + 1 == vocab_size
+ ):
+ if v_ax1 == 0:
+ output_index[v_ax0, 0] = indices[v_ax0, 0]
+ elif (
+ usample[v_ax0, T.int64(0)]
+ >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
+ ):
+ output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
+
+ cumsum_sorted = cumsum(sorted_prob, axis=1)
+
+ cumsum_mask = tensor_expr_op(
+ lambda cumsum_sorted, top_p, top_k: te.compute(
+ cumsum_sorted.shape,
+ lambda i, j: _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]),
+ name="get_cumsum_mask_top_p_top_k",
+ ),
+ "get_cumsum_mask_top_p_top_k",
+ args=[cumsum_sorted, top_p, top_k],
+ )
+
+ renorm_prob = tensor_ir_op(
+ _get_renorm_prob,
+ "get_renorm_prob",
+ args=[cumsum_sorted, cumsum_mask],
+ out=Tensor.placeholder(
+ [batch, 1],
+ prob_dtype,
+ ),
+ )
+
+ out_index_in_sorted = tensor_ir_op(
+ _get_index_from_sorted,
+ "get_index_from_sorted",
+ args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
+ out=Tensor.placeholder([batch, 1], index_dtype),
+ )
+ return out_index_in_sorted
+
+
+def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k):
+ """Renormalizes probabilities after filtering with top_p and top_k, ensuring
+ they sum up to 1.
+
+ Notes
+ -----
+ For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
+
+ Parameters
+ ----------
+ prob : Tensor
+ A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
+
+ sorted_prob : Tensor
+ Probabilities sorted in descending order.
+
+ top_p : Tensor
+ The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
+
+ top_k :Tensor
+ A tensor with shape (batch, 1), representing the number of top probabilities
+ to consider for top-k sampling.
+
+ Returns
+ -------
+ result : Tensor
+ The filtered and nomalized tensor with the sampe shape as input prob.
+ """
+ dtype = prob.dtype
+ batch = sorted_prob.shape[0]
+
+ @T.prim_func(private=True)
+ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle):
Review Comment:
Make `cumsum_mask` a function that takes in the axis indices (`i`, `j`), then inline
it via metaprogramming:
```python
def cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
    return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]