shingjan commented on code in PR #12124:
URL: https://github.com/apache/tvm/pull/12124#discussion_r930499403


##########
python/tvm/relay/op/transform.py:
##########
@@ -1889,3 +1889,46 @@ def stft(
         window = _make.ones([n_fft], "int32")
 
     return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
+
+
+def trilu(data, k, upper=True):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data: relay.Expr
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k: int
+        The diagonal offset relative to the main diagonal (k=0). Positive
+        values shift the boundary above the main diagonal, negative values below it.
+
+    upper: bool, optional
+        If True (the default), only the upper triangular values of the input
+        are kept; if False, the lower triangular values are kept.
+
+
+    Returns
+    -------
+    ret : relay.Expr
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        relay.trilu(x, 0, True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    if not isinstance(k, Expr):

Review Comment:
   Nit: Should we check whether `k` is an `int`, wrap it with `const(k, "int32")`, 
and throw an error otherwise? I'm asking because in PyTorch's triu/tril ops `k` 
is actually an int rather than a 0-D tensor. This would also cover the case of 
wrongly typed user input.
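   A minimal sketch of the check suggested above, written as a NumPy reference rather than against Relay. `trilu_reference` is a hypothetical helper, not TVM's implementation; in the actual op, the accepted `int` would be wrapped with `const(k, "int32")` instead of being passed to `np.triu`/`np.tril`. It accepts plain integers for `k` and rejects everything else with a `TypeError`:

```python
import numpy as np


def trilu_reference(data, k=0, upper=True):
    """Hypothetical NumPy reference for trilu (not TVM's implementation).

    Keeps the upper (or lower) triangle of the last two axes of `data`,
    offset by `k` diagonals, and zeroes everything else.
    """
    # Mirror the review suggestion: accept a Python/NumPy integer for k and
    # reject anything else. bool is a subclass of int, so exclude it
    # explicitly; otherwise trilu(x, True) would silently mean k=1.
    if isinstance(k, (bool, np.bool_)) or not isinstance(k, (int, np.integer)):
        raise TypeError(f"k must be an int, got {type(k).__name__}")
    arr = np.asarray(data)
    if arr.ndim < 2:
        raise ValueError("trilu expects a 2-D matrix or a batch of 2-D matrices")
    # np.triu / np.tril operate on the last two axes, so batches work as-is.
    return np.triu(arr, int(k)) if upper else np.tril(arr, int(k))


x = np.arange(9).reshape(3, 3)
upper = trilu_reference(x, 0, upper=True)    # [[0, 1, 2], [0, 4, 5], [0, 0, 8]]
lower = trilu_reference(x, -1, upper=False)  # [[0, 0, 0], [3, 0, 0], [6, 7, 0]]
```

   The explicit `bool` rejection matters here: a bare `isinstance(k, int)` check would accept `True`, so a call with the arguments swapped, like `trilu(x, True, 0)`, would not be caught as wrongly typed input.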


